feat(assistants/main.py): add assistants api streaming support

This commit is contained in:
Krrish Dholakia 2024-06-04 16:30:35 -07:00
parent 7b474ec267
commit f3d78532f9
9 changed files with 444 additions and 65 deletions

View file

@ -905,6 +905,14 @@ async def arun_thread(
)
def run_thread_stream(
*,
event_handler: Optional[AssistantEventHandler] = None,
**kwargs,
) -> AssistantStreamManager[AssistantEventHandler]:
return run_thread(stream=True, event_handler=event_handler, **kwargs) # type: ignore
def run_thread(
custom_llm_provider: Literal["openai", "azure"],
thread_id: str,
@ -916,6 +924,7 @@ def run_thread(
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[Any] = None,
event_handler: Optional[AssistantEventHandler] = None, # for stream=True calls
**kwargs,
) -> Run:
"""Run a given thread + assistant."""
@ -959,6 +968,7 @@ def run_thread(
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.run_thread(
thread_id=thread_id,
assistant_id=assistant_id,
@ -975,6 +985,7 @@ def run_thread(
organization=organization,
client=client,
arun_thread=arun_thread,
event_handler=event_handler,
)
elif custom_llm_provider == "azure":
api_base = (

View file

@ -31,6 +31,10 @@ from ..types.llms.openai import (
Thread,
AssistantToolParam,
Run,
AssistantEventHandler,
AsyncAssistantEventHandler,
AsyncAssistantStreamManager,
AssistantStreamManager,
)
@ -1975,27 +1979,67 @@ class AzureAssistantsAPI(BaseLLM):
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id="thread_OHLZkEj5xJLxdk0REZ4cl9sP",
assistant_id="asst_nIzr656D1GIVMLHOKD76bN2T",
additional_instructions=None,
instructions=None,
metadata=None,
model=None,
tools=None,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
# response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
# thread_id=thread_id,
# assistant_id=assistant_id,
# additional_instructions=additional_instructions,
# instructions=instructions,
# metadata=metadata,
# model=model,
# tools=tools,
# )
return response
async def async_run_thread_stream(
self,
client: AsyncAzureOpenAI,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
data = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
"instructions": instructions,
"metadata": metadata,
"model": model,
"tools": tools,
}
if event_handler is not None:
data["event_handler"] = event_handler
return client.beta.threads.runs.stream(**data) # type: ignore
def run_thread_stream(
self,
client: AzureOpenAI,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
) -> AssistantStreamManager[AssistantEventHandler]:
data = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
"instructions": instructions,
"metadata": metadata,
"model": model,
"tools": tools,
}
if event_handler is not None:
data["event_handler"] = event_handler
return client.beta.threads.runs.stream(**data) # type: ignore
# fmt: off
@overload
@ -2062,8 +2106,30 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
if arun_thread is not None and arun_thread == True:
if stream is not None and stream == True:
azure_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
return self.async_run_thread_stream(
client=azure_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
@ -2091,6 +2157,19 @@ class AzureAssistantsAPI(BaseLLM):
client=client,
)
if stream is not None and stream == True:
return self.run_thread_stream(
client=openai_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,

View file

@ -2534,6 +2534,56 @@ class OpenAIAssistantsAPI(BaseLLM):
return response
async def async_run_thread_stream(
self,
client: AsyncOpenAI,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
data = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
"instructions": instructions,
"metadata": metadata,
"model": model,
"tools": tools,
}
if event_handler is not None:
data["event_handler"] = event_handler
return client.beta.threads.runs.stream(**data) # type: ignore
def run_thread_stream(
self,
client: OpenAI,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
) -> AssistantStreamManager[AssistantEventHandler]:
data = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
"instructions": instructions,
"metadata": metadata,
"model": model,
"tools": tools,
}
if event_handler is not None:
data["event_handler"] = event_handler
return client.beta.threads.runs.stream(**data) # type: ignore
# fmt: off
@overload
@ -2554,6 +2604,7 @@ class OpenAIAssistantsAPI(BaseLLM):
organization: Optional[str],
client,
arun_thread: Literal[True],
event_handler: Optional[AssistantEventHandler],
) -> Coroutine[None, None, Run]:
...
@ -2575,6 +2626,7 @@ class OpenAIAssistantsAPI(BaseLLM):
organization: Optional[str],
client,
arun_thread: Optional[Literal[False]],
event_handler: Optional[AssistantEventHandler],
) -> Run:
...
@ -2597,8 +2649,29 @@ class OpenAIAssistantsAPI(BaseLLM):
organization: Optional[str],
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
if arun_thread is not None and arun_thread == True:
if stream is not None and stream == True:
_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
return self.async_run_thread_stream(
client=_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
@ -2624,6 +2697,19 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client,
)
if stream is not None and stream == True:
return self.run_thread_stream(
client=openai_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,

View file

@ -40,8 +40,14 @@ model_list:
vertex_project: my-project-9d5c
vertex_location: us-central1
assistant_settings:
custom_llm_provider: openai
litellm_params:
api_key: os.environ/OPENAI_API_KEY
router_settings:
enable_pre_call_checks: true
litellm_settings:
success_callback: ["langfuse"]
success_callback: ["langfuse"]

View file

@ -117,7 +117,13 @@ import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
from litellm.router import (
LiteLLM_Params,
Deployment,
updateDeployment,
ModelGroupInfo,
AssistantsTypedDict,
)
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
@ -2853,6 +2859,17 @@ class ProxyConfig:
if "ollama" in litellm_model_name and litellm_model_api_base is None:
run_ollama_serve()
## ASSISTANT SETTINGS
assistants_config: Optional[AssistantsTypedDict] = None
assistant_settings = config.get("assistant_settings", None)
if assistant_settings:
for k, v in assistant_settings["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
_v = v.replace("os.environ/", "")
v = os.getenv(_v)
assistant_settings["litellm_params"][k] = v
assistants_config = AssistantsTypedDict(**assistant_settings) # type: ignore
## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None)
if router_settings and isinstance(router_settings, dict):
@ -2868,7 +2885,9 @@ class ProxyConfig:
for k, v in router_settings.items():
if k in available_args:
router_params[k] = v
router = litellm.Router(**router_params) # type:ignore
router = litellm.Router(
**router_params, assistants_config=assistants_config
) # type:ignore
return router, model_list, general_settings
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
@ -3606,6 +3625,60 @@ def data_generator(response):
yield f"data: {json.dumps(chunk)}\n\n"
async def async_assistants_data_generator(
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
verbose_proxy_logger.debug("inside generator")
try:
start_time = time.time()
async with response as chunk:
### CALL HOOKS ### - modify outgoing data
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
user_api_key_dict=user_api_key_dict, response=chunk
)
# chunk = chunk.model_dump_json(exclude_none=True)
async for c in chunk:
c = c.model_dump_json(exclude_none=True)
try:
yield f"data: {c}\n\n"
except Exception as e:
yield f"data: {str(e)}\n\n"
# Streaming is done, yield the [DONE] chunk
done_message = "[DONE]"
yield f"data: {done_message}\n\n"
except Exception as e:
traceback.print_exc()
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=request_data,
)
verbose_proxy_logger.debug(
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
)
router_model_names = llm_router.model_names if llm_router is not None else []
if user_debug:
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
proxy_exception = ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
error_returned = json.dumps({"error": proxy_exception.to_dict()})
yield f"data: {error_returned}\n\n"
async def async_data_generator(
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
@ -5347,7 +5420,6 @@ async def get_assistants(
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
@ -5405,9 +5477,7 @@ async def get_assistants(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_assistants(
custom_llm_provider="openai", client=None, **data
)
response = await llm_router.aget_assistants(**data)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -5479,7 +5549,6 @@ async def create_threads(
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
@ -5537,9 +5606,7 @@ async def create_threads(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.acreate_thread(
custom_llm_provider="openai", client=None, **data
)
response = await llm_router.acreate_thread(**data)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -5667,9 +5734,7 @@ async def get_thread(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_thread(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
response = await llm_router.aget_thread(thread_id=thread_id, **data)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -5800,9 +5865,7 @@ async def add_messages(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.a_add_message(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
response = await llm_router.a_add_message(thread_id=thread_id, **data)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -5929,9 +5992,7 @@ async def get_messages(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_messages(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
response = await llm_router.aget_messages(thread_id=thread_id, **data)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -6060,9 +6121,19 @@ async def run_thread(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.arun_thread(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
response = await llm_router.arun_thread(thread_id=thread_id, **data)
if (
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
return StreamingResponse(
async_assistants_data_generator(
user_api_key_dict=user_api_key_dict,
response=response,
request_data=data,
),
media_type="text/event-stream",
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting

View file

@ -51,6 +51,7 @@ from litellm.types.router import (
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
AssistantsTypedDict,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
@ -78,6 +79,8 @@ class Router:
def __init__(
self,
model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
## ASSISTANTS API ##
assistants_config: Optional[AssistantsTypedDict] = None,
## CACHING ##
redis_url: Optional[str] = None,
redis_host: Optional[str] = None,
@ -212,6 +215,7 @@ class Router:
elif debug_level == "DEBUG":
verbose_router_logger.setLevel(logging.DEBUG)
self.assistants_config = assistants_config
self.deployment_names: List = (
[]
) # names of models under litellm_params. ex. azure/chatgpt-v-2
@ -1831,31 +1835,56 @@ class Router:
async def aget_assistants(
self,
custom_llm_provider: Literal["openai"],
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[Assistant]:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_assistants(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def acreate_thread(
self,
custom_llm_provider: Literal["openai"],
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.acreate_thread(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def aget_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
@ -1865,15 +1894,24 @@ class Router:
async def a_add_message(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> OpenAIMessage:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.a_add_message(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
@ -1887,11 +1925,19 @@ class Router:
async def aget_messages(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[OpenAIMessage]:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_messages(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
@ -1901,9 +1947,9 @@ class Router:
async def arun_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
assistant_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
@ -1913,6 +1959,16 @@ class Router:
client: Optional[Any] = None,
**kwargs,
) -> Run:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.arun_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,

View file

@ -18,7 +18,9 @@ from litellm.llms.openai import (
OpenAIMessage as Message,
AsyncCursorPage,
SyncCursorPage,
AssistantEventHandler,
)
from typing_extensions import override
"""
V0 Scope:
@ -129,10 +131,46 @@ async def test_add_message_litellm(sync_mode, provider):
assert isinstance(added_message, Message)
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False])
class EventHandler(AssistantEventHandler):
@override
def on_text_created(self, text) -> None:
print(f"\nassistant > ", end="", flush=True)
@override
def on_text_delta(self, delta, snapshot):
print(delta.value, end="", flush=True)
def on_tool_call_created(self, tool_call):
print(f"\nassistant > {tool_call.type}\n", flush=True)
def on_tool_call_delta(self, delta, snapshot):
if delta.type == "code_interpreter":
if delta.code_interpreter.input:
print(delta.code_interpreter.input, end="", flush=True)
if delta.code_interpreter.outputs:
print(f"\n\noutput >", flush=True)
for output in delta.code_interpreter.outputs:
if output.type == "logs":
print(f"\n{output.logs}", flush=True)
@pytest.mark.parametrize(
"provider",
[
"azure",
"openai",
],
) #
@pytest.mark.parametrize(
"sync_mode",
[False, True],
) #
@pytest.mark.parametrize(
"is_streaming",
[True, False],
) #
@pytest.mark.asyncio
async def test_aarun_thread_litellm(sync_mode, provider):
async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
"""
- Get Assistants
- Create thread
@ -163,27 +201,48 @@ async def test_aarun_thread_litellm(sync_mode, provider):
if sync_mode:
added_message = litellm.add_message(**data)
run = litellm.run_thread(assistant_id=assistant_id, **data)
if run.status == "completed":
messages = litellm.get_messages(
thread_id=_new_thread.id, custom_llm_provider=provider
if is_streaming:
run = litellm.run_thread_stream(
assistant_id=assistant_id, event_handler=EventHandler(), **data
)
assert isinstance(messages.data[0], Message)
with run as run:
assert isinstance(run, AssistantEventHandler)
print(run)
run.until_done()
else:
pytest.fail("An unexpected error occurred when running the thread")
run = litellm.run_thread(
assistant_id=assistant_id, stream=is_streaming, **data
)
if run.status == "completed":
messages = litellm.get_messages(
thread_id=_new_thread.id, custom_llm_provider=provider
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")
else:
added_message = await litellm.a_add_message(**data)
run = await litellm.arun_thread(
custom_llm_provider=provider, thread_id=thread_id, assistant_id=assistant_id
)
if run.status == "completed":
messages = await litellm.aget_messages(
thread_id=_new_thread.id, custom_llm_provider=provider
if is_streaming:
run = litellm.run_thread_stream(
assistant_id=assistant_id, event_handler=EventHandler(), **data
)
assert isinstance(messages.data[0], Message)
with run as run:
assert isinstance(run, AssistantEventHandler)
print(run)
run.until_done()
else:
pytest.fail("An unexpected error occurred when running the thread")
run = await litellm.arun_thread(
custom_llm_provider=provider,
thread_id=thread_id,
assistant_id=assistant_id,
)
if run.status == "completed":
messages = await litellm.aget_messages(
thread_id=_new_thread.id, custom_llm_provider=provider
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")

View file

@ -13,6 +13,12 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.thread_create_params import (
Message as OpenAICreateThreadParamsMessage,
)
from openai.lib.streaming._assistants import (
AssistantEventHandler,
AssistantStreamManager,
AsyncAssistantStreamManager,
AsyncAssistantEventHandler,
)
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.threads.run import Run
from openai.types.beta.assistant import Assistant

View file

@ -446,3 +446,8 @@ class ModelGroupInfo(BaseModel):
supports_vision: bool = Field(default=False)
supports_function_calling: bool = Field(default=False)
supported_openai_params: List[str] = Field(default=[])
class AssistantsTypedDict(TypedDict):
custom_llm_provider: Literal["azure", "openai"]
litellm_params: LiteLLMParamsTypedDict