Merge pull request #4012 from BerriAI/litellm_assistants_streaming

feat(assistants/main.py): add assistants api streaming support
Krish Dholakia authored on 2024-06-04 20:58:02 -07:00, committed by GitHub
commit 7127e6bcd8
11 changed files with 458 additions and 75 deletions

View file

@ -135,10 +135,17 @@ print(f"run_thread: {run_thread}")
</TabItem>
<TabItem value="proxy" label="PROXY">
```yaml
assistant_settings:
  custom_llm_provider: azure
  litellm_params:
    api_key: os.environ/AZURE_API_KEY
    api_base: os.environ/AZURE_API_BASE
    api_version: os.environ/AZURE_API_VERSION
```

```bash
$ litellm --config /path/to/config.yaml

# RUNNING on http://0.0.0.0:4000
```
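Once the proxy is running, a streaming run can be requested against the `/v1/threads/{thread_id}/runs` route updated later in this PR in `proxy_server.py`. A minimal sketch, assuming the proxy is reachable at `http://0.0.0.0:4000`, a virtual key of `sk-1234`, and placeholder thread/assistant ids; when `stream` is true the proxy answers with a `text/event-stream` body terminated by a `[DONE]` line:

```python
import httpx

# Assumed values: proxy URL, virtual key, and pre-created thread/assistant ids.
PROXY_BASE = "http://0.0.0.0:4000"
headers = {"Authorization": "Bearer sk-1234"}

with httpx.stream(
    "POST",
    f"{PROXY_BASE}/v1/threads/thread_abc123/runs",
    headers=headers,
    json={"assistant_id": "asst_abc123", "stream": True},
    timeout=600,
) as resp:
    for line in resp.iter_lines():
        # Server-sent events: each payload line looks like "data: {...}"
        if line.startswith("data: "):
            print(line[len("data: "):])
```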

View file

@ -846,6 +846,15 @@ def get_messages(
### RUNS ###
def arun_thread_stream(
    *,
    event_handler: Optional[AssistantEventHandler] = None,
    **kwargs,
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
    kwargs["arun_thread"] = True
    return run_thread(stream=True, event_handler=event_handler, **kwargs)  # type: ignore


async def arun_thread(
    custom_llm_provider: Literal["openai", "azure"],
    thread_id: str,
@ -905,6 +914,14 @@ async def arun_thread(
)
def run_thread_stream(
    *,
    event_handler: Optional[AssistantEventHandler] = None,
    **kwargs,
) -> AssistantStreamManager[AssistantEventHandler]:
    return run_thread(stream=True, event_handler=event_handler, **kwargs)  # type: ignore


def run_thread(
    custom_llm_provider: Literal["openai", "azure"],
    thread_id: str,
@ -916,6 +933,7 @@ def run_thread(
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[Any] = None,
event_handler: Optional[AssistantEventHandler] = None, # for stream=True calls
**kwargs,
) -> Run:
"""Run a given thread + assistant."""
@ -959,6 +977,7 @@ def run_thread(
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.run_thread(
thread_id=thread_id,
assistant_id=assistant_id,
@ -975,6 +994,7 @@ def run_thread(
organization=organization,
client=client,
arun_thread=arun_thread,
event_handler=event_handler,
)
elif custom_llm_provider == "azure":
api_base = (
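For reference, the test changes later in this PR exercise the new helpers roughly as follows. A minimal sketch, assuming OpenAI credentials in the environment and an existing thread and assistant (the ids below are placeholders):

```python
import litellm

# Placeholders - substitute ids created via litellm.create_thread / litellm.get_assistants.
thread_id = "thread_abc123"
assistant_id = "asst_abc123"

# Synchronous streaming: run_thread_stream returns an AssistantStreamManager,
# used as a context manager that yields an AssistantEventHandler.
run = litellm.run_thread_stream(
    custom_llm_provider="openai",
    thread_id=thread_id,
    assistant_id=assistant_id,
)
with run as handler:
    handler.until_done()  # block until the streamed run finishes
```

The async variant, `litellm.arun_thread_stream(...)`, is used the same way with `async with` (see the updated test file below).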

View file

@ -31,6 +31,10 @@ from ..types.llms.openai import (
Thread,
AssistantToolParam,
Run,
AssistantEventHandler,
AsyncAssistantEventHandler,
AsyncAssistantStreamManager,
AssistantStreamManager,
)
@ -1975,27 +1979,67 @@ class AzureAssistantsAPI(BaseLLM):
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id="thread_OHLZkEj5xJLxdk0REZ4cl9sP",
assistant_id="asst_nIzr656D1GIVMLHOKD76bN2T",
additional_instructions=None,
instructions=None,
metadata=None,
model=None,
tools=None,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
# response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
# thread_id=thread_id,
# assistant_id=assistant_id,
# additional_instructions=additional_instructions,
# instructions=instructions,
# metadata=metadata,
# model=model,
# tools=tools,
# )
return response
    def async_run_thread_stream(
        self,
        client: AsyncAzureOpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
    ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
        data = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data)  # type: ignore

    def run_thread_stream(
        self,
        client: AzureOpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
    ) -> AssistantStreamManager[AssistantEventHandler]:
        data = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data)  # type: ignore
# fmt: off
@overload
@ -2062,8 +2106,30 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
        if arun_thread is not None and arun_thread == True:
            if stream is not None and stream == True:
                azure_client = self.async_get_azure_client(
                    api_key=api_key,
                    api_base=api_base,
                    api_version=api_version,
                    azure_ad_token=azure_ad_token,
                    timeout=timeout,
                    max_retries=max_retries,
                    client=client,
                )
                return self.async_run_thread_stream(
                    client=azure_client,
                    thread_id=thread_id,
                    assistant_id=assistant_id,
                    additional_instructions=additional_instructions,
                    instructions=instructions,
                    metadata=metadata,
                    model=model,
                    tools=tools,
                    event_handler=event_handler,
                )
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
@ -2091,6 +2157,19 @@ class AzureAssistantsAPI(BaseLLM):
client=client,
)
        if stream is not None and stream == True:
            return self.run_thread_stream(
                client=openai_client,
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                model=model,
                tools=tools,
                event_handler=event_handler,
            )
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,

View file

@ -2534,6 +2534,56 @@ class OpenAIAssistantsAPI(BaseLLM):
return response
    def async_run_thread_stream(
        self,
        client: AsyncOpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
    ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
        data = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data)  # type: ignore

    def run_thread_stream(
        self,
        client: OpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
    ) -> AssistantStreamManager[AssistantEventHandler]:
        data = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data)  # type: ignore
# fmt: off
@overload
@ -2554,6 +2604,7 @@ class OpenAIAssistantsAPI(BaseLLM):
organization: Optional[str],
client,
arun_thread: Literal[True],
event_handler: Optional[AssistantEventHandler],
) -> Coroutine[None, None, Run]:
...
@ -2575,6 +2626,7 @@ class OpenAIAssistantsAPI(BaseLLM):
organization: Optional[str],
client,
arun_thread: Optional[Literal[False]],
event_handler: Optional[AssistantEventHandler],
) -> Run:
...
@ -2597,8 +2649,29 @@ class OpenAIAssistantsAPI(BaseLLM):
organization: Optional[str],
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
        if arun_thread is not None and arun_thread == True:
            if stream is not None and stream == True:
                _client = self.async_get_openai_client(
                    api_key=api_key,
                    api_base=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                )
                return self.async_run_thread_stream(
                    client=_client,
                    thread_id=thread_id,
                    assistant_id=assistant_id,
                    additional_instructions=additional_instructions,
                    instructions=instructions,
                    metadata=metadata,
                    model=model,
                    tools=tools,
                    event_handler=event_handler,
                )
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
@ -2624,6 +2697,19 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client,
)
        if stream is not None and stream == True:
            return self.run_thread_stream(
                client=openai_client,
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                model=model,
                tools=tools,
                event_handler=event_handler,
            )
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
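The optional `event_handler` is passed straight through to `client.beta.threads.runs.stream(...)`, so callers can plug in a subclass of the OpenAI SDK's `AssistantEventHandler` to react to events as they arrive. A rough sketch (the handler hooks come from the `openai` package; the ids and printing behavior here are illustrative assumptions, not part of this PR):

```python
from typing_extensions import override
from openai import AssistantEventHandler

import litellm


class PrintingHandler(AssistantEventHandler):
    """Prints streamed text deltas as the run progresses."""

    @override
    def on_text_delta(self, delta, snapshot) -> None:
        print(delta.value or "", end="", flush=True)


stream = litellm.run_thread_stream(
    custom_llm_provider="openai",
    thread_id="thread_abc123",   # placeholder id
    assistant_id="asst_abc123",  # placeholder id
    event_handler=PrintingHandler(),
)
with stream as handler:
    handler.until_done()
```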

View file

@ -487,9 +487,16 @@ class PredibaseChatCompletion(BaseLLM):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
try:
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
except httpx.HTTPStatusError as e:
raise PredibaseError(
status_code=e.response.status_code, message=e.response.text
)
except Exception as e:
raise PredibaseError(status_code=500, message=str(e))
return self.process_response(
model=model,
response=response,

View file

@ -40,8 +40,14 @@ model_list:
      vertex_project: my-project-9d5c
      vertex_location: us-central1

assistant_settings:
  custom_llm_provider: openai
  litellm_params:
    api_key: os.environ/OPENAI_API_KEY

router_settings:
  enable_pre_call_checks: true

litellm_settings:
  success_callback: ["langfuse"]

View file

@ -117,7 +117,13 @@ import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
from litellm.router import (
LiteLLM_Params,
Deployment,
updateDeployment,
ModelGroupInfo,
AssistantsTypedDict,
)
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
@ -2859,6 +2865,17 @@ class ProxyConfig:
if "ollama" in litellm_model_name and litellm_model_api_base is None:
run_ollama_serve()
        ## ASSISTANT SETTINGS
        assistants_config: Optional[AssistantsTypedDict] = None
        assistant_settings = config.get("assistant_settings", None)
        if assistant_settings:
            for k, v in assistant_settings["litellm_params"].items():
                if isinstance(v, str) and v.startswith("os.environ/"):
                    _v = v.replace("os.environ/", "")
                    v = os.getenv(_v)
                    assistant_settings["litellm_params"][k] = v
            assistants_config = AssistantsTypedDict(**assistant_settings)  # type: ignore
## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None)
if router_settings and isinstance(router_settings, dict):
@ -2874,7 +2891,9 @@ class ProxyConfig:
for k, v in router_settings.items():
if k in available_args:
router_params[k] = v
router = litellm.Router(**router_params) # type:ignore
router = litellm.Router(
**router_params, assistants_config=assistants_config
) # type:ignore
return router, model_list, general_settings
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
@ -3612,6 +3631,60 @@ def data_generator(response):
yield f"data: {json.dumps(chunk)}\n\n"
async def async_assistants_data_generator(
    response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
    verbose_proxy_logger.debug("inside generator")
    try:
        start_time = time.time()
        async with response as chunk:
            ### CALL HOOKS ### - modify outgoing data
            chunk = await proxy_logging_obj.async_post_call_streaming_hook(
                user_api_key_dict=user_api_key_dict, response=chunk
            )
            # chunk = chunk.model_dump_json(exclude_none=True)
            async for c in chunk:
                c = c.model_dump_json(exclude_none=True)
                try:
                    yield f"data: {c}\n\n"
                except Exception as e:
                    yield f"data: {str(e)}\n\n"

        # Streaming is done, yield the [DONE] chunk
        done_message = "[DONE]"
        yield f"data: {done_message}\n\n"
    except Exception as e:
        traceback.print_exc()
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=request_data,
        )
        verbose_proxy_logger.debug(
            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
        )
        router_model_names = llm_router.model_names if llm_router is not None else []
        if user_debug:
            traceback.print_exc()

        if isinstance(e, HTTPException):
            raise e
        else:
            error_traceback = traceback.format_exc()
            error_msg = f"{str(e)}\n\n{error_traceback}"
            proxy_exception = ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
            )
            error_returned = json.dumps({"error": proxy_exception.to_dict()})
            yield f"data: {error_returned}\n\n"
async def async_data_generator(
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
@ -5353,7 +5426,6 @@ async def get_assistants(
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
@ -5411,9 +5483,7 @@ async def get_assistants(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_assistants(
custom_llm_provider="openai", client=None, **data
)
response = await llm_router.aget_assistants(**data)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -5485,7 +5555,6 @@ async def create_threads(
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
@ -5543,9 +5612,7 @@ async def create_threads(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.acreate_thread(
custom_llm_provider="openai", client=None, **data
)
response = await llm_router.acreate_thread(**data)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -5673,9 +5740,7 @@ async def get_thread(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_thread(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
response = await llm_router.aget_thread(thread_id=thread_id, **data)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -5806,9 +5871,7 @@ async def add_messages(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.a_add_message(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
response = await llm_router.a_add_message(thread_id=thread_id, **data)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -5935,9 +5998,7 @@ async def get_messages(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_messages(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
response = await llm_router.aget_messages(thread_id=thread_id, **data)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@ -5984,12 +6045,12 @@ async def get_messages(
)
@router.get(
@router.post(
"/v1/threads/{thread_id}/runs",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.get(
@router.post(
"/threads/{thread_id}/runs",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
@ -6066,9 +6127,19 @@ async def run_thread(
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.arun_thread(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
response = await llm_router.arun_thread(thread_id=thread_id, **data)
        if (
            "stream" in data and data["stream"] == True
        ):  # use generate_responses to stream responses
            return StreamingResponse(
                async_assistants_data_generator(
                    user_api_key_dict=user_api_key_dict,
                    response=response,
                    request_data=data,
                ),
                media_type="text/event-stream",
            )
### ALERTING ###
data["litellm_status"] = "success" # used for alerting

View file

@ -51,6 +51,7 @@ from litellm.types.router import (
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
AssistantsTypedDict,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
@ -78,6 +79,8 @@ class Router:
def __init__(
self,
model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
## ASSISTANTS API ##
assistants_config: Optional[AssistantsTypedDict] = None,
## CACHING ##
redis_url: Optional[str] = None,
redis_host: Optional[str] = None,
@ -212,6 +215,7 @@ class Router:
elif debug_level == "DEBUG":
verbose_router_logger.setLevel(logging.DEBUG)
self.assistants_config = assistants_config
self.deployment_names: List = (
[]
) # names of models under litellm_params. ex. azure/chatgpt-v-2
@ -1831,31 +1835,56 @@ class Router:
async def aget_assistants(
self,
custom_llm_provider: Literal["openai"],
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[Assistant]:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.aget_assistants(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )
async def acreate_thread(
self,
custom_llm_provider: Literal["openai"],
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.acreate_thread(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )
async def aget_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.aget_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
@ -1865,15 +1894,24 @@ class Router:
async def a_add_message(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> OpenAIMessage:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.a_add_message(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
@ -1887,11 +1925,19 @@ class Router:
async def aget_messages(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[OpenAIMessage]:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.aget_messages(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
@ -1901,9 +1947,9 @@ class Router:
async def arun_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
assistant_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
@ -1913,6 +1959,16 @@ class Router:
client: Optional[Any] = None,
**kwargs,
) -> Run:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.arun_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
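With `assistants_config` set on the `Router`, callers no longer need to pass `custom_llm_provider` (or provider credentials) on every assistants call; the router fills them in from the config. A minimal sketch, assuming an Azure assistants deployment, the environment variables below, and a placeholder assistant id:

```python
import os
import asyncio

from litellm import Router

router = Router(
    assistants_config={
        "custom_llm_provider": "azure",
        "litellm_params": {
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_version": os.getenv("AZURE_API_VERSION"),
        },
    },
)


async def main():
    # custom_llm_provider is resolved from assistants_config
    thread = await router.acreate_thread()
    run = await router.arun_thread(
        thread_id=thread.id,
        assistant_id="asst_abc123",  # placeholder id
    )
    print(run.status)


asyncio.run(main())
```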

View file

@ -18,7 +18,10 @@ from litellm.llms.openai import (
OpenAIMessage as Message,
AsyncCursorPage,
SyncCursorPage,
AssistantEventHandler,
AsyncAssistantEventHandler,
)
from typing_extensions import override
"""
V0 Scope:
@ -129,10 +132,26 @@ async def test_add_message_litellm(sync_mode, provider):
assert isinstance(added_message, Message)
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
"provider",
[
"azure",
"openai",
],
) #
@pytest.mark.parametrize(
"sync_mode",
[
True,
False,
],
)
@pytest.mark.parametrize(
"is_streaming",
[True, False],
) #
@pytest.mark.asyncio
async def test_aarun_thread_litellm(sync_mode, provider):
async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
"""
- Get Assistants
- Create thread
@ -163,27 +182,48 @@ async def test_aarun_thread_litellm(sync_mode, provider):
if sync_mode:
added_message = litellm.add_message(**data)
run = litellm.run_thread(assistant_id=assistant_id, **data)
if run.status == "completed":
messages = litellm.get_messages(
thread_id=_new_thread.id, custom_llm_provider=provider
)
assert isinstance(messages.data[0], Message)
if is_streaming:
run = litellm.run_thread_stream(assistant_id=assistant_id, **data)
with run as run:
assert isinstance(run, AssistantEventHandler)
print(run)
run.until_done()
else:
pytest.fail("An unexpected error occurred when running the thread")
run = litellm.run_thread(
assistant_id=assistant_id, stream=is_streaming, **data
)
if run.status == "completed":
messages = litellm.get_messages(
thread_id=_new_thread.id, custom_llm_provider=provider
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")
else:
added_message = await litellm.a_add_message(**data)
run = await litellm.arun_thread(
custom_llm_provider=provider, thread_id=thread_id, assistant_id=assistant_id
)
if run.status == "completed":
messages = await litellm.aget_messages(
thread_id=_new_thread.id, custom_llm_provider=provider
)
assert isinstance(messages.data[0], Message)
if is_streaming:
run = litellm.arun_thread_stream(assistant_id=assistant_id, **data)
async with run as run:
print(f"run: {run}")
assert isinstance(
run,
AsyncAssistantEventHandler,
)
print(run)
run.until_done()
else:
pytest.fail("An unexpected error occurred when running the thread")
run = await litellm.arun_thread(
custom_llm_provider=provider,
thread_id=thread_id,
assistant_id=assistant_id,
)
if run.status == "completed":
messages = await litellm.aget_messages(
thread_id=_new_thread.id, custom_llm_provider=provider
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")

View file

@ -13,6 +13,12 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.thread_create_params import (
Message as OpenAICreateThreadParamsMessage,
)
from openai.lib.streaming._assistants import (
AssistantEventHandler,
AssistantStreamManager,
AsyncAssistantStreamManager,
AsyncAssistantEventHandler,
)
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.threads.run import Run
from openai.types.beta.assistant import Assistant

View file

@ -446,3 +446,8 @@ class ModelGroupInfo(BaseModel):
supports_vision: bool = Field(default=False)
supports_function_calling: bool = Field(default=False)
supported_openai_params: List[str] = Field(default=[])
class AssistantsTypedDict(TypedDict):
    custom_llm_provider: Literal["azure", "openai"]
    litellm_params: LiteLLMParamsTypedDict
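For reference, this is the shape the proxy builds from the `assistant_settings` block in the config (see the `ProxyConfig` change above). A rough sketch of that mapping, with the `os.environ/<VAR>` indirection resolved the same way the proxy does it:

```python
import os

from litellm.types.router import AssistantsTypedDict

# Mirrors the proxy's handling of `os.environ/<VAR>` values in assistant_settings.
raw_settings = {
    "custom_llm_provider": "openai",
    "litellm_params": {"api_key": "os.environ/OPENAI_API_KEY"},
}

for k, v in raw_settings["litellm_params"].items():
    if isinstance(v, str) and v.startswith("os.environ/"):
        raw_settings["litellm_params"][k] = os.getenv(v.replace("os.environ/", ""))

assistants_config = AssistantsTypedDict(**raw_settings)  # type: ignore
```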