Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
feat(assistants/main.py): add assistants api streaming support
Commit f3d78532f9 (parent 7b474ec267)
9 changed files with 444 additions and 65 deletions
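
Taken together, the hunks below add a streaming entry point (`run_thread_stream` / `stream=True`) to the assistants API, wire it through the Azure and OpenAI handlers, the `Router`, and the proxy, and extend the existing test. As a quick orientation, here is a minimal sketch of the new surface as the test file exercises it; the thread and assistant IDs are placeholders, and the exact keyword set is assumed from the diff rather than confirmed against a released package:

from typing_extensions import override
from openai.lib.streaming._assistants import AssistantEventHandler

import litellm


class PrintingHandler(AssistantEventHandler):
    """Prints streamed text deltas as the run progresses."""

    @override
    def on_text_delta(self, delta, snapshot) -> None:
        print(delta.value, end="", flush=True)


# run_thread_stream() forwards to run_thread(stream=True, ...) and returns an
# AssistantStreamManager; entering it yields the event handler, and until_done()
# blocks until the run finishes.
run = litellm.run_thread_stream(
    custom_llm_provider="openai",    # or "azure"
    thread_id="thread_...",          # placeholder
    assistant_id="asst_...",         # placeholder
    event_handler=PrintingHandler(),
)
with run as handler:
    handler.until_done()
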
@@ -905,6 +905,14 @@ async def arun_thread(
    )


def run_thread_stream(
    *,
    event_handler: Optional[AssistantEventHandler] = None,
    **kwargs,
) -> AssistantStreamManager[AssistantEventHandler]:
    return run_thread(stream=True, event_handler=event_handler, **kwargs) # type: ignore


def run_thread(
    custom_llm_provider: Literal["openai", "azure"],
    thread_id: str,
@@ -916,6 +924,7 @@ def run_thread(
    stream: Optional[bool] = None,
    tools: Optional[Iterable[AssistantToolParam]] = None,
    client: Optional[Any] = None,
    event_handler: Optional[AssistantEventHandler] = None, # for stream=True calls
    **kwargs,
) -> Run:
    """Run a given thread + assistant."""
@@ -959,6 +968,7 @@ def run_thread(
            or litellm.openai_key
            or os.getenv("OPENAI_API_KEY")
        )

        response = openai_assistants_api.run_thread(
            thread_id=thread_id,
            assistant_id=assistant_id,
@@ -975,6 +985,7 @@ def run_thread(
            organization=organization,
            client=client,
            arun_thread=arun_thread,
            event_handler=event_handler,
        )
    elif custom_llm_provider == "azure":
        api_base = (

@@ -31,6 +31,10 @@ from ..types.llms.openai import (
    Thread,
    AssistantToolParam,
    Run,
    AssistantEventHandler,
    AsyncAssistantEventHandler,
    AsyncAssistantStreamManager,
    AssistantStreamManager,
)
@@ -1975,27 +1979,67 @@ class AzureAssistantsAPI(BaseLLM):
        )

        response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
            thread_id="thread_OHLZkEj5xJLxdk0REZ4cl9sP",
            assistant_id="asst_nIzr656D1GIVMLHOKD76bN2T",
            additional_instructions=None,
            instructions=None,
            metadata=None,
            model=None,
            tools=None,
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            tools=tools,
        )

        # response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
        #     thread_id=thread_id,
        #     assistant_id=assistant_id,
        #     additional_instructions=additional_instructions,
        #     instructions=instructions,
        #     metadata=metadata,
        #     model=model,
        #     tools=tools,
        # )

        return response

    async def async_run_thread_stream(
        self,
        client: AsyncAzureOpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
    ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
        data = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data) # type: ignore

    def run_thread_stream(
        self,
        client: AzureOpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
    ) -> AssistantStreamManager[AssistantEventHandler]:
        data = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data) # type: ignore

    # fmt: off

    @overload
@@ -2062,8 +2106,30 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client=None,
        arun_thread=None,
        event_handler: Optional[AssistantEventHandler] = None,
    ):
        if arun_thread is not None and arun_thread == True:
            if stream is not None and stream == True:
                azure_client = self.async_get_azure_client(
                    api_key=api_key,
                    api_base=api_base,
                    api_version=api_version,
                    azure_ad_token=azure_ad_token,
                    timeout=timeout,
                    max_retries=max_retries,
                    client=client,
                )
                return self.async_run_thread_stream(
                    client=azure_client,
                    thread_id=thread_id,
                    assistant_id=assistant_id,
                    additional_instructions=additional_instructions,
                    instructions=instructions,
                    metadata=metadata,
                    model=model,
                    tools=tools,
                    event_handler=event_handler,
                )
            return self.arun_thread(
                thread_id=thread_id,
                assistant_id=assistant_id,
@@ -2091,6 +2157,19 @@ class AzureAssistantsAPI(BaseLLM):
            client=client,
        )

        if stream is not None and stream == True:
            return self.run_thread_stream(
                client=openai_client,
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                model=model,
                tools=tools,
                event_handler=event_handler,
            )

        response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
            thread_id=thread_id,
            assistant_id=assistant_id,

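The Azure handler keeps `create_and_poll()` for the blocking path (it returns the terminal `Run`) and only switches to `client.beta.threads.runs.stream()` when `stream=True`. A rough sketch of the two call shapes against a bare `AzureOpenAI` client, using the stream-manager helpers from the `openai` package; the endpoint, API version, and IDs below are assumptions, not values from this commit:

import os

from openai import AzureOpenAI

azure_client = AzureOpenAI(
    api_key=os.environ["AZURE_API_KEY"],            # assumed env var names
    api_version="2024-02-15-preview",               # assumed API version
    azure_endpoint=os.environ["AZURE_API_BASE"],
)

# Polling path: blocks until the run reaches a terminal state and returns the Run.
run = azure_client.beta.threads.runs.create_and_poll(
    thread_id="thread_...",    # placeholder
    assistant_id="asst_...",   # placeholder
)
print(run.status)

# Streaming path: returns a stream manager; text deltas arrive incrementally.
with azure_client.beta.threads.runs.stream(
    thread_id="thread_...",
    assistant_id="asst_...",
) as stream:
    for text in stream.text_deltas:
        print(text, end="", flush=True)
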
@@ -2534,6 +2534,56 @@ class OpenAIAssistantsAPI(BaseLLM):

        return response

    async def async_run_thread_stream(
        self,
        client: AsyncOpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
    ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
        data = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data) # type: ignore

    def run_thread_stream(
        self,
        client: OpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
    ) -> AssistantStreamManager[AssistantEventHandler]:
        data = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data) # type: ignore

    # fmt: off

    @overload
@@ -2554,6 +2604,7 @@ class OpenAIAssistantsAPI(BaseLLM):
        organization: Optional[str],
        client,
        arun_thread: Literal[True],
        event_handler: Optional[AssistantEventHandler],
    ) -> Coroutine[None, None, Run]:
        ...

@@ -2575,6 +2626,7 @@ class OpenAIAssistantsAPI(BaseLLM):
        organization: Optional[str],
        client,
        arun_thread: Optional[Literal[False]],
        event_handler: Optional[AssistantEventHandler],
    ) -> Run:
        ...

@@ -2597,8 +2649,29 @@ class OpenAIAssistantsAPI(BaseLLM):
        organization: Optional[str],
        client=None,
        arun_thread=None,
        event_handler: Optional[AssistantEventHandler] = None,
    ):
        if arun_thread is not None and arun_thread == True:
            if stream is not None and stream == True:
                _client = self.async_get_openai_client(
                    api_key=api_key,
                    api_base=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                )
                return self.async_run_thread_stream(
                    client=_client,
                    thread_id=thread_id,
                    assistant_id=assistant_id,
                    additional_instructions=additional_instructions,
                    instructions=instructions,
                    metadata=metadata,
                    model=model,
                    tools=tools,
                    event_handler=event_handler,
                )
            return self.arun_thread(
                thread_id=thread_id,
                assistant_id=assistant_id,
@@ -2624,6 +2697,19 @@ class OpenAIAssistantsAPI(BaseLLM):
            client=client,
        )

        if stream is not None and stream == True:
            return self.run_thread_stream(
                client=openai_client,
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                model=model,
                tools=tools,
                event_handler=event_handler,
            )

        response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
            thread_id=thread_id,
            assistant_id=assistant_id,

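On the async side, `run_thread(..., arun_thread=True, stream=True)` hands back the coroutine from `async_run_thread_stream`, so awaiting `litellm.arun_thread(stream=True, ...)` should yield an `AsyncAssistantStreamManager`. A hedged sketch of consuming it, mirroring the `async with` / `async for` pattern the proxy generator below uses; the IDs are placeholders and the return type is inferred from these hunks:

import asyncio

import litellm


async def main() -> None:
    stream_manager = await litellm.arun_thread(
        custom_llm_provider="openai",
        thread_id="thread_...",    # placeholder
        assistant_id="asst_...",   # placeholder
        stream=True,
    )
    # Entering the manager yields the async event handler; iterating it yields
    # raw assistant stream events until the run completes.
    async with stream_manager as handler:
        async for event in handler:
            print(type(event).__name__)


asyncio.run(main())
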
@@ -40,8 +40,14 @@ model_list:
      vertex_project: my-project-9d5c
      vertex_location: us-central1

assistant_settings:
  custom_llm_provider: openai
  litellm_params:
    api_key: os.environ/OPENAI_API_KEY

router_settings:
  enable_pre_call_checks: true

litellm_settings:
  success_callback: ["langfuse"]

@@ -117,7 +117,13 @@ import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
from litellm.router import (
    LiteLLM_Params,
    Deployment,
    updateDeployment,
    ModelGroupInfo,
    AssistantsTypedDict,
)
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -2853,6 +2859,17 @@ class ProxyConfig:
            if "ollama" in litellm_model_name and litellm_model_api_base is None:
                run_ollama_serve()

        ## ASSISTANT SETTINGS
        assistants_config: Optional[AssistantsTypedDict] = None
        assistant_settings = config.get("assistant_settings", None)
        if assistant_settings:
            for k, v in assistant_settings["litellm_params"].items():
                if isinstance(v, str) and v.startswith("os.environ/"):
                    _v = v.replace("os.environ/", "")
                    v = os.getenv(_v)
                    assistant_settings["litellm_params"][k] = v
            assistants_config = AssistantsTypedDict(**assistant_settings) # type: ignore

        ## ROUTER SETTINGS (e.g. routing_strategy, ...)
        router_settings = config.get("router_settings", None)
        if router_settings and isinstance(router_settings, dict):
@@ -2868,7 +2885,9 @@ class ProxyConfig:
            for k, v in router_settings.items():
                if k in available_args:
                    router_params[k] = v
        router = litellm.Router(**router_params) # type:ignore
        router = litellm.Router(
            **router_params, assistants_config=assistants_config
        ) # type:ignore
        return router, model_list, general_settings

    def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
@@ -3606,6 +3625,60 @@ def data_generator(response):
        yield f"data: {json.dumps(chunk)}\n\n"


async def async_assistants_data_generator(
    response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
    verbose_proxy_logger.debug("inside generator")
    try:
        start_time = time.time()
        async with response as chunk:

            ### CALL HOOKS ### - modify outgoing data
            chunk = await proxy_logging_obj.async_post_call_streaming_hook(
                user_api_key_dict=user_api_key_dict, response=chunk
            )

            # chunk = chunk.model_dump_json(exclude_none=True)
            async for c in chunk:
                c = c.model_dump_json(exclude_none=True)
                try:
                    yield f"data: {c}\n\n"
                except Exception as e:
                    yield f"data: {str(e)}\n\n"

        # Streaming is done, yield the [DONE] chunk
        done_message = "[DONE]"
        yield f"data: {done_message}\n\n"
    except Exception as e:
        traceback.print_exc()
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=request_data,
        )
        verbose_proxy_logger.debug(
            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
        )
        router_model_names = llm_router.model_names if llm_router is not None else []
        if user_debug:
            traceback.print_exc()

        if isinstance(e, HTTPException):
            raise e
        else:
            error_traceback = traceback.format_exc()
            error_msg = f"{str(e)}\n\n{error_traceback}"

        proxy_exception = ProxyException(
            message=getattr(e, "message", error_msg),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
        )
        error_returned = json.dumps({"error": proxy_exception.to_dict()})
        yield f"data: {error_returned}\n\n"


async def async_data_generator(
    response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
@@ -5347,7 +5420,6 @@ async def get_assistants(
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = { # type: ignore
@@ -5405,9 +5477,7 @@ async def get_assistants(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.aget_assistants(
            custom_llm_provider="openai", client=None, **data
        )
        response = await llm_router.aget_assistants(**data)

        ### ALERTING ###
        data["litellm_status"] = "success" # used for alerting
@@ -5479,7 +5549,6 @@ async def create_threads(
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = { # type: ignore
@@ -5537,9 +5606,7 @@ async def create_threads(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.acreate_thread(
            custom_llm_provider="openai", client=None, **data
        )
        response = await llm_router.acreate_thread(**data)

        ### ALERTING ###
        data["litellm_status"] = "success" # used for alerting
@@ -5667,9 +5734,7 @@ async def get_thread(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.aget_thread(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )
        response = await llm_router.aget_thread(thread_id=thread_id, **data)

        ### ALERTING ###
        data["litellm_status"] = "success" # used for alerting
@@ -5800,9 +5865,7 @@ async def add_messages(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.a_add_message(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )
        response = await llm_router.a_add_message(thread_id=thread_id, **data)

        ### ALERTING ###
        data["litellm_status"] = "success" # used for alerting
@@ -5929,9 +5992,7 @@ async def get_messages(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.aget_messages(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )
        response = await llm_router.aget_messages(thread_id=thread_id, **data)

        ### ALERTING ###
        data["litellm_status"] = "success" # used for alerting
@@ -6060,9 +6121,19 @@ async def run_thread(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.arun_thread(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )
        response = await llm_router.arun_thread(thread_id=thread_id, **data)

        if (
            "stream" in data and data["stream"] == True
        ): # use generate_responses to stream responses
            return StreamingResponse(
                async_assistants_data_generator(
                    user_api_key_dict=user_api_key_dict,
                    response=response,
                    request_data=data,
                ),
                media_type="text/event-stream",
            )

        ### ALERTING ###
        data["litellm_status"] = "success" # used for alerting

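`async_assistants_data_generator` re-serializes each streamed event as a server-sent-events `data:` line and closes with `data: [DONE]`, so a proxy client just reads SSE off the run endpoint. A sketch of that from the client side; the proxy address, route, and key below are assumptions based on the proxy's OpenAI-compatible conventions, not values from this commit:

import httpx

PROXY_BASE = "http://localhost:4000"   # assumed proxy address
THREAD_ID = "thread_..."               # placeholder

with httpx.stream(
    "POST",
    f"{PROXY_BASE}/v1/threads/{THREAD_ID}/runs",     # assumed OpenAI-compatible route
    headers={"Authorization": "Bearer sk-1234"},     # assumed proxy key
    json={"assistant_id": "asst_...", "stream": True},
    timeout=None,
) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":    # terminal sentinel emitted by the generator
            break
        print(payload)             # one JSON-serialized assistant stream event
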
@@ -51,6 +51,7 @@ from litellm.types.router import (
    AlertingConfig,
    DeploymentTypedDict,
    ModelGroupInfo,
    AssistantsTypedDict,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -78,6 +79,8 @@ class Router:
    def __init__(
        self,
        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
        ## ASSISTANTS API ##
        assistants_config: Optional[AssistantsTypedDict] = None,
        ## CACHING ##
        redis_url: Optional[str] = None,
        redis_host: Optional[str] = None,
@@ -212,6 +215,7 @@ class Router:
        elif debug_level == "DEBUG":
            verbose_router_logger.setLevel(logging.DEBUG)

        self.assistants_config = assistants_config
        self.deployment_names: List = (
            []
        ) # names of models under litellm_params. ex. azure/chatgpt-v-2
@@ -1831,31 +1835,56 @@ class Router:

    async def aget_assistants(
        self,
        custom_llm_provider: Literal["openai"],
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AsyncCursorPage[Assistant]:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.aget_assistants(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def acreate_thread(
        self,
        custom_llm_provider: Literal["openai"],
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Thread:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.acreate_thread(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def aget_thread(
        self,
        custom_llm_provider: Literal["openai"],
        thread_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Thread:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.aget_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
@@ -1865,15 +1894,24 @@ class Router:

    async def a_add_message(
        self,
        custom_llm_provider: Literal["openai"],
        thread_id: str,
        role: Literal["user", "assistant"],
        content: str,
        attachments: Optional[List[Attachment]] = None,
        metadata: Optional[dict] = None,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> OpenAIMessage:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.a_add_message(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
@@ -1887,11 +1925,19 @@ class Router:

    async def aget_messages(
        self,
        custom_llm_provider: Literal["openai"],
        thread_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AsyncCursorPage[OpenAIMessage]:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.aget_messages(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
@@ -1901,9 +1947,9 @@ class Router:

    async def arun_thread(
        self,
        custom_llm_provider: Literal["openai"],
        thread_id: str,
        assistant_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        additional_instructions: Optional[str] = None,
        instructions: Optional[str] = None,
        metadata: Optional[dict] = None,
@@ -1913,6 +1959,16 @@ class Router:
        client: Optional[Any] = None,
        **kwargs,
    ) -> Run:

        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.arun_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,

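With `assistants_config` passed to the `Router`, the per-call `custom_llm_provider` argument becomes optional: each method falls back to the configured provider, merges in its `litellm_params`, and raises if neither is supplied. A minimal sketch of that fallback, assuming `OPENAI_API_KEY` is set; the thread and assistant IDs are placeholders:

import asyncio
import os

from litellm import Router

router = Router(
    assistants_config={
        "custom_llm_provider": "openai",
        "litellm_params": {"api_key": os.getenv("OPENAI_API_KEY")},
    },
)


async def main() -> None:
    # No custom_llm_provider here: the router pulls it (plus litellm_params)
    # from assistants_config before delegating to litellm.arun_thread.
    run = await router.arun_thread(
        thread_id="thread_...",    # placeholder
        assistant_id="asst_...",   # placeholder
    )
    print(run.status)


asyncio.run(main())
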
@@ -18,7 +18,9 @@ from litellm.llms.openai import (
    OpenAIMessage as Message,
    AsyncCursorPage,
    SyncCursorPage,
    AssistantEventHandler,
)
from typing_extensions import override

"""
V0 Scope:
@@ -129,10 +131,46 @@ async def test_add_message_litellm(sync_mode, provider):
    assert isinstance(added_message, Message)


@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False])
class EventHandler(AssistantEventHandler):
    @override
    def on_text_created(self, text) -> None:
        print(f"\nassistant > ", end="", flush=True)

    @override
    def on_text_delta(self, delta, snapshot):
        print(delta.value, end="", flush=True)

    def on_tool_call_created(self, tool_call):
        print(f"\nassistant > {tool_call.type}\n", flush=True)

    def on_tool_call_delta(self, delta, snapshot):
        if delta.type == "code_interpreter":
            if delta.code_interpreter.input:
                print(delta.code_interpreter.input, end="", flush=True)
            if delta.code_interpreter.outputs:
                print(f"\n\noutput >", flush=True)
                for output in delta.code_interpreter.outputs:
                    if output.type == "logs":
                        print(f"\n{output.logs}", flush=True)


@pytest.mark.parametrize(
    "provider",
    [
        "azure",
        "openai",
    ],
) #
@pytest.mark.parametrize(
    "sync_mode",
    [False, True],
) #
@pytest.mark.parametrize(
    "is_streaming",
    [True, False],
) #
@pytest.mark.asyncio
async def test_aarun_thread_litellm(sync_mode, provider):
async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
    """
    - Get Assistants
    - Create thread
@@ -163,27 +201,48 @@ async def test_aarun_thread_litellm(sync_mode, provider):
    if sync_mode:
        added_message = litellm.add_message(**data)

        run = litellm.run_thread(assistant_id=assistant_id, **data)

        if run.status == "completed":
            messages = litellm.get_messages(
                thread_id=_new_thread.id, custom_llm_provider=provider
        if is_streaming:
            run = litellm.run_thread_stream(
                assistant_id=assistant_id, event_handler=EventHandler(), **data
            )
            assert isinstance(messages.data[0], Message)
            with run as run:
                assert isinstance(run, AssistantEventHandler)
                print(run)
                run.until_done()
        else:
            pytest.fail("An unexpected error occurred when running the thread")
            run = litellm.run_thread(
                assistant_id=assistant_id, stream=is_streaming, **data
            )
            if run.status == "completed":
                messages = litellm.get_messages(
                    thread_id=_new_thread.id, custom_llm_provider=provider
                )
                assert isinstance(messages.data[0], Message)
            else:
                pytest.fail("An unexpected error occurred when running the thread")

    else:
        added_message = await litellm.a_add_message(**data)

        run = await litellm.arun_thread(
            custom_llm_provider=provider, thread_id=thread_id, assistant_id=assistant_id
        )

        if run.status == "completed":
            messages = await litellm.aget_messages(
                thread_id=_new_thread.id, custom_llm_provider=provider
        if is_streaming:
            run = litellm.run_thread_stream(
                assistant_id=assistant_id, event_handler=EventHandler(), **data
            )
            assert isinstance(messages.data[0], Message)
            with run as run:
                assert isinstance(run, AssistantEventHandler)
                print(run)
                run.until_done()
        else:
            pytest.fail("An unexpected error occurred when running the thread")
            run = await litellm.arun_thread(
                custom_llm_provider=provider,
                thread_id=thread_id,
                assistant_id=assistant_id,
            )

            if run.status == "completed":
                messages = await litellm.aget_messages(
                    thread_id=_new_thread.id, custom_llm_provider=provider
                )
                assert isinstance(messages.data[0], Message)
            else:
                pytest.fail("An unexpected error occurred when running the thread")

@@ -13,6 +13,12 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.thread_create_params import (
    Message as OpenAICreateThreadParamsMessage,
)
from openai.lib.streaming._assistants import (
    AssistantEventHandler,
    AssistantStreamManager,
    AsyncAssistantStreamManager,
    AsyncAssistantEventHandler,
)
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.threads.run import Run
from openai.types.beta.assistant import Assistant

@@ -446,3 +446,8 @@ class ModelGroupInfo(BaseModel):
    supports_vision: bool = Field(default=False)
    supports_function_calling: bool = Field(default=False)
    supported_openai_params: List[str] = Field(default=[])


class AssistantsTypedDict(TypedDict):
    custom_llm_provider: Literal["azure", "openai"]
    litellm_params: LiteLLMParamsTypedDict