mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
Merge pull request #4012 from BerriAI/litellm_assistants_streaming
feat(assistants/main.py): add assistants api streaming support
This commit is contained in:
commit
7127e6bcd8
11 changed files with 458 additions and 75 deletions
|
@ -135,10 +135,17 @@ print(f"run_thread: {run_thread}")
|
|||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
```bash
|
||||
$ export OPENAI_API_KEY="sk-..."
|
||||
```yaml
|
||||
assistant_settings:
|
||||
custom_llm_provider: azure
|
||||
litellm_params:
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_version: os.environ/AZURE_API_VERSION
|
||||
```
|
||||
|
||||
$ litellm
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
|
|
@ -846,6 +846,15 @@ def get_messages(
|
|||
|
||||
|
||||
### RUNS ###
|
||||
def arun_thread_stream(
|
||||
*,
|
||||
event_handler: Optional[AssistantEventHandler] = None,
|
||||
**kwargs,
|
||||
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
|
||||
kwargs["arun_thread"] = True
|
||||
return run_thread(stream=True, event_handler=event_handler, **kwargs) # type: ignore
|
||||
|
||||
|
||||
async def arun_thread(
|
||||
custom_llm_provider: Literal["openai", "azure"],
|
||||
thread_id: str,
|
||||
|
@ -905,6 +914,14 @@ async def arun_thread(
|
|||
)
|
||||
|
||||
|
||||
def run_thread_stream(
|
||||
*,
|
||||
event_handler: Optional[AssistantEventHandler] = None,
|
||||
**kwargs,
|
||||
) -> AssistantStreamManager[AssistantEventHandler]:
|
||||
return run_thread(stream=True, event_handler=event_handler, **kwargs) # type: ignore
|
||||
|
||||
|
||||
def run_thread(
|
||||
custom_llm_provider: Literal["openai", "azure"],
|
||||
thread_id: str,
|
||||
|
@ -916,6 +933,7 @@ def run_thread(
|
|||
stream: Optional[bool] = None,
|
||||
tools: Optional[Iterable[AssistantToolParam]] = None,
|
||||
client: Optional[Any] = None,
|
||||
event_handler: Optional[AssistantEventHandler] = None, # for stream=True calls
|
||||
**kwargs,
|
||||
) -> Run:
|
||||
"""Run a given thread + assistant."""
|
||||
|
@ -959,6 +977,7 @@ def run_thread(
|
|||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_assistants_api.run_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
|
@ -975,6 +994,7 @@ def run_thread(
|
|||
organization=organization,
|
||||
client=client,
|
||||
arun_thread=arun_thread,
|
||||
event_handler=event_handler,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = (
|
||||
|
|
|
@ -31,6 +31,10 @@ from ..types.llms.openai import (
|
|||
Thread,
|
||||
AssistantToolParam,
|
||||
Run,
|
||||
AssistantEventHandler,
|
||||
AsyncAssistantEventHandler,
|
||||
AsyncAssistantStreamManager,
|
||||
AssistantStreamManager,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1975,27 +1979,67 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
)
|
||||
|
||||
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
thread_id="thread_OHLZkEj5xJLxdk0REZ4cl9sP",
|
||||
assistant_id="asst_nIzr656D1GIVMLHOKD76bN2T",
|
||||
additional_instructions=None,
|
||||
instructions=None,
|
||||
metadata=None,
|
||||
model=None,
|
||||
tools=None,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
# thread_id=thread_id,
|
||||
# assistant_id=assistant_id,
|
||||
# additional_instructions=additional_instructions,
|
||||
# instructions=instructions,
|
||||
# metadata=metadata,
|
||||
# model=model,
|
||||
# tools=tools,
|
||||
# )
|
||||
|
||||
return response
|
||||
|
||||
def async_run_thread_stream(
|
||||
self,
|
||||
client: AsyncAzureOpenAI,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
event_handler: Optional[AssistantEventHandler],
|
||||
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
|
||||
data = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"additional_instructions": additional_instructions,
|
||||
"instructions": instructions,
|
||||
"metadata": metadata,
|
||||
"model": model,
|
||||
"tools": tools,
|
||||
}
|
||||
if event_handler is not None:
|
||||
data["event_handler"] = event_handler
|
||||
return client.beta.threads.runs.stream(**data) # type: ignore
|
||||
|
||||
def run_thread_stream(
|
||||
self,
|
||||
client: AzureOpenAI,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
event_handler: Optional[AssistantEventHandler],
|
||||
) -> AssistantStreamManager[AssistantEventHandler]:
|
||||
data = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"additional_instructions": additional_instructions,
|
||||
"instructions": instructions,
|
||||
"metadata": metadata,
|
||||
"model": model,
|
||||
"tools": tools,
|
||||
}
|
||||
if event_handler is not None:
|
||||
data["event_handler"] = event_handler
|
||||
return client.beta.threads.runs.stream(**data) # type: ignore
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
|
@ -2062,8 +2106,30 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client=None,
|
||||
arun_thread=None,
|
||||
event_handler: Optional[AssistantEventHandler] = None,
|
||||
):
|
||||
if arun_thread is not None and arun_thread == True:
|
||||
if stream is not None and stream == True:
|
||||
azure_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
return self.async_run_thread_stream(
|
||||
client=azure_client,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
tools=tools,
|
||||
event_handler=event_handler,
|
||||
)
|
||||
return self.arun_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
|
@ -2091,6 +2157,19 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
client=client,
|
||||
)
|
||||
|
||||
if stream is not None and stream == True:
|
||||
return self.run_thread_stream(
|
||||
client=openai_client,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
tools=tools,
|
||||
event_handler=event_handler,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
|
|
|
@ -2534,6 +2534,56 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
|
||||
return response
|
||||
|
||||
def async_run_thread_stream(
|
||||
self,
|
||||
client: AsyncOpenAI,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
event_handler: Optional[AssistantEventHandler],
|
||||
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
|
||||
data = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"additional_instructions": additional_instructions,
|
||||
"instructions": instructions,
|
||||
"metadata": metadata,
|
||||
"model": model,
|
||||
"tools": tools,
|
||||
}
|
||||
if event_handler is not None:
|
||||
data["event_handler"] = event_handler
|
||||
return client.beta.threads.runs.stream(**data) # type: ignore
|
||||
|
||||
def run_thread_stream(
|
||||
self,
|
||||
client: OpenAI,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
event_handler: Optional[AssistantEventHandler],
|
||||
) -> AssistantStreamManager[AssistantEventHandler]:
|
||||
data = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"additional_instructions": additional_instructions,
|
||||
"instructions": instructions,
|
||||
"metadata": metadata,
|
||||
"model": model,
|
||||
"tools": tools,
|
||||
}
|
||||
if event_handler is not None:
|
||||
data["event_handler"] = event_handler
|
||||
return client.beta.threads.runs.stream(**data) # type: ignore
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
|
@ -2554,6 +2604,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
organization: Optional[str],
|
||||
client,
|
||||
arun_thread: Literal[True],
|
||||
event_handler: Optional[AssistantEventHandler],
|
||||
) -> Coroutine[None, None, Run]:
|
||||
...
|
||||
|
||||
|
@ -2575,6 +2626,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
organization: Optional[str],
|
||||
client,
|
||||
arun_thread: Optional[Literal[False]],
|
||||
event_handler: Optional[AssistantEventHandler],
|
||||
) -> Run:
|
||||
...
|
||||
|
||||
|
@ -2597,8 +2649,29 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
organization: Optional[str],
|
||||
client=None,
|
||||
arun_thread=None,
|
||||
event_handler: Optional[AssistantEventHandler] = None,
|
||||
):
|
||||
if arun_thread is not None and arun_thread == True:
|
||||
if stream is not None and stream == True:
|
||||
_client = self.async_get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
return self.async_run_thread_stream(
|
||||
client=_client,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
tools=tools,
|
||||
event_handler=event_handler,
|
||||
)
|
||||
return self.arun_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
|
@ -2624,6 +2697,19 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
client=client,
|
||||
)
|
||||
|
||||
if stream is not None and stream == True:
|
||||
return self.run_thread_stream(
|
||||
client=openai_client,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
tools=tools,
|
||||
event_handler=event_handler,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
|
|
|
@ -487,9 +487,16 @@ class PredibaseChatCompletion(BaseLLM):
|
|||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
try:
|
||||
response = await self.async_handler.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise PredibaseError(
|
||||
status_code=e.response.status_code, message=e.response.text
|
||||
)
|
||||
except Exception as e:
|
||||
raise PredibaseError(status_code=500, message=str(e))
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
|
|
|
@ -40,6 +40,12 @@ model_list:
|
|||
vertex_project: my-project-9d5c
|
||||
vertex_location: us-central1
|
||||
|
||||
assistant_settings:
|
||||
custom_llm_provider: openai
|
||||
litellm_params:
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
|
||||
router_settings:
|
||||
enable_pre_call_checks: true
|
||||
|
||||
|
|
|
@ -117,7 +117,13 @@ import pydantic
|
|||
from litellm.proxy._types import *
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
|
||||
from litellm.router import (
|
||||
LiteLLM_Params,
|
||||
Deployment,
|
||||
updateDeployment,
|
||||
ModelGroupInfo,
|
||||
AssistantsTypedDict,
|
||||
)
|
||||
from litellm.router import ModelInfo as RouterModelInfo
|
||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
|
@ -2859,6 +2865,17 @@ class ProxyConfig:
|
|||
if "ollama" in litellm_model_name and litellm_model_api_base is None:
|
||||
run_ollama_serve()
|
||||
|
||||
## ASSISTANT SETTINGS
|
||||
assistants_config: Optional[AssistantsTypedDict] = None
|
||||
assistant_settings = config.get("assistant_settings", None)
|
||||
if assistant_settings:
|
||||
for k, v in assistant_settings["litellm_params"].items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
_v = v.replace("os.environ/", "")
|
||||
v = os.getenv(_v)
|
||||
assistant_settings["litellm_params"][k] = v
|
||||
assistants_config = AssistantsTypedDict(**assistant_settings) # type: ignore
|
||||
|
||||
## ROUTER SETTINGS (e.g. routing_strategy, ...)
|
||||
router_settings = config.get("router_settings", None)
|
||||
if router_settings and isinstance(router_settings, dict):
|
||||
|
@ -2874,7 +2891,9 @@ class ProxyConfig:
|
|||
for k, v in router_settings.items():
|
||||
if k in available_args:
|
||||
router_params[k] = v
|
||||
router = litellm.Router(**router_params) # type:ignore
|
||||
router = litellm.Router(
|
||||
**router_params, assistants_config=assistants_config
|
||||
) # type:ignore
|
||||
return router, model_list, general_settings
|
||||
|
||||
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
|
||||
|
@ -3612,6 +3631,60 @@ def data_generator(response):
|
|||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
|
||||
|
||||
async def async_assistants_data_generator(
|
||||
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
||||
):
|
||||
verbose_proxy_logger.debug("inside generator")
|
||||
try:
|
||||
start_time = time.time()
|
||||
async with response as chunk:
|
||||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=chunk
|
||||
)
|
||||
|
||||
# chunk = chunk.model_dump_json(exclude_none=True)
|
||||
async for c in chunk:
|
||||
c = c.model_dump_json(exclude_none=True)
|
||||
try:
|
||||
yield f"data: {c}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {str(e)}\n\n"
|
||||
|
||||
# Streaming is done, yield the [DONE] chunk
|
||||
done_message = "[DONE]"
|
||||
yield f"data: {done_message}\n\n"
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=request_data,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
||||
)
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
if user_debug:
|
||||
traceback.print_exc()
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
|
||||
proxy_exception = ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
error_returned = json.dumps({"error": proxy_exception.to_dict()})
|
||||
yield f"data: {error_returned}\n\n"
|
||||
|
||||
|
||||
async def async_data_generator(
|
||||
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
||||
):
|
||||
|
@ -5353,7 +5426,6 @@ async def get_assistants(
|
|||
try:
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
data = orjson.loads(body)
|
||||
|
||||
# Include original request and headers in the data
|
||||
data["proxy_server_request"] = { # type: ignore
|
||||
|
@ -5411,9 +5483,7 @@ async def get_assistants(
|
|||
raise HTTPException(
|
||||
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
||||
)
|
||||
response = await llm_router.aget_assistants(
|
||||
custom_llm_provider="openai", client=None, **data
|
||||
)
|
||||
response = await llm_router.aget_assistants(**data)
|
||||
|
||||
### ALERTING ###
|
||||
data["litellm_status"] = "success" # used for alerting
|
||||
|
@ -5485,7 +5555,6 @@ async def create_threads(
|
|||
try:
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
data = orjson.loads(body)
|
||||
|
||||
# Include original request and headers in the data
|
||||
data["proxy_server_request"] = { # type: ignore
|
||||
|
@ -5543,9 +5612,7 @@ async def create_threads(
|
|||
raise HTTPException(
|
||||
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
||||
)
|
||||
response = await llm_router.acreate_thread(
|
||||
custom_llm_provider="openai", client=None, **data
|
||||
)
|
||||
response = await llm_router.acreate_thread(**data)
|
||||
|
||||
### ALERTING ###
|
||||
data["litellm_status"] = "success" # used for alerting
|
||||
|
@ -5673,9 +5740,7 @@ async def get_thread(
|
|||
raise HTTPException(
|
||||
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
||||
)
|
||||
response = await llm_router.aget_thread(
|
||||
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
|
||||
)
|
||||
response = await llm_router.aget_thread(thread_id=thread_id, **data)
|
||||
|
||||
### ALERTING ###
|
||||
data["litellm_status"] = "success" # used for alerting
|
||||
|
@ -5806,9 +5871,7 @@ async def add_messages(
|
|||
raise HTTPException(
|
||||
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
||||
)
|
||||
response = await llm_router.a_add_message(
|
||||
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
|
||||
)
|
||||
response = await llm_router.a_add_message(thread_id=thread_id, **data)
|
||||
|
||||
### ALERTING ###
|
||||
data["litellm_status"] = "success" # used for alerting
|
||||
|
@ -5935,9 +5998,7 @@ async def get_messages(
|
|||
raise HTTPException(
|
||||
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
||||
)
|
||||
response = await llm_router.aget_messages(
|
||||
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
|
||||
)
|
||||
response = await llm_router.aget_messages(thread_id=thread_id, **data)
|
||||
|
||||
### ALERTING ###
|
||||
data["litellm_status"] = "success" # used for alerting
|
||||
|
@ -5984,12 +6045,12 @@ async def get_messages(
|
|||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
@router.post(
|
||||
"/v1/threads/{thread_id}/runs",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["assistants"],
|
||||
)
|
||||
@router.get(
|
||||
@router.post(
|
||||
"/threads/{thread_id}/runs",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["assistants"],
|
||||
|
@ -6066,8 +6127,18 @@ async def run_thread(
|
|||
raise HTTPException(
|
||||
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
||||
)
|
||||
response = await llm_router.arun_thread(
|
||||
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
|
||||
response = await llm_router.arun_thread(thread_id=thread_id, **data)
|
||||
|
||||
if (
|
||||
"stream" in data and data["stream"] == True
|
||||
): # use generate_responses to stream responses
|
||||
return StreamingResponse(
|
||||
async_assistants_data_generator(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
response=response,
|
||||
request_data=data,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
|
|
|
@ -51,6 +51,7 @@ from litellm.types.router import (
|
|||
AlertingConfig,
|
||||
DeploymentTypedDict,
|
||||
ModelGroupInfo,
|
||||
AssistantsTypedDict,
|
||||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
|
@ -78,6 +79,8 @@ class Router:
|
|||
def __init__(
|
||||
self,
|
||||
model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
|
||||
## ASSISTANTS API ##
|
||||
assistants_config: Optional[AssistantsTypedDict] = None,
|
||||
## CACHING ##
|
||||
redis_url: Optional[str] = None,
|
||||
redis_host: Optional[str] = None,
|
||||
|
@ -212,6 +215,7 @@ class Router:
|
|||
elif debug_level == "DEBUG":
|
||||
verbose_router_logger.setLevel(logging.DEBUG)
|
||||
|
||||
self.assistants_config = assistants_config
|
||||
self.deployment_names: List = (
|
||||
[]
|
||||
) # names of models under litellm_params. ex. azure/chatgpt-v-2
|
||||
|
@ -1831,31 +1835,56 @@ class Router:
|
|||
|
||||
async def aget_assistants(
|
||||
self,
|
||||
custom_llm_provider: Literal["openai"],
|
||||
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
|
||||
client: Optional[AsyncOpenAI] = None,
|
||||
**kwargs,
|
||||
) -> AsyncCursorPage[Assistant]:
|
||||
if custom_llm_provider is None:
|
||||
if self.assistants_config is not None:
|
||||
custom_llm_provider = self.assistants_config["custom_llm_provider"]
|
||||
kwargs.update(self.assistants_config["litellm_params"])
|
||||
else:
|
||||
raise Exception(
|
||||
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
|
||||
)
|
||||
|
||||
return await litellm.aget_assistants(
|
||||
custom_llm_provider=custom_llm_provider, client=client, **kwargs
|
||||
)
|
||||
|
||||
async def acreate_thread(
|
||||
self,
|
||||
custom_llm_provider: Literal["openai"],
|
||||
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
|
||||
client: Optional[AsyncOpenAI] = None,
|
||||
**kwargs,
|
||||
) -> Thread:
|
||||
if custom_llm_provider is None:
|
||||
if self.assistants_config is not None:
|
||||
custom_llm_provider = self.assistants_config["custom_llm_provider"]
|
||||
kwargs.update(self.assistants_config["litellm_params"])
|
||||
else:
|
||||
raise Exception(
|
||||
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
|
||||
)
|
||||
return await litellm.acreate_thread(
|
||||
custom_llm_provider=custom_llm_provider, client=client, **kwargs
|
||||
)
|
||||
|
||||
async def aget_thread(
|
||||
self,
|
||||
custom_llm_provider: Literal["openai"],
|
||||
thread_id: str,
|
||||
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
|
||||
client: Optional[AsyncOpenAI] = None,
|
||||
**kwargs,
|
||||
) -> Thread:
|
||||
if custom_llm_provider is None:
|
||||
if self.assistants_config is not None:
|
||||
custom_llm_provider = self.assistants_config["custom_llm_provider"]
|
||||
kwargs.update(self.assistants_config["litellm_params"])
|
||||
else:
|
||||
raise Exception(
|
||||
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
|
||||
)
|
||||
return await litellm.aget_thread(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
thread_id=thread_id,
|
||||
|
@ -1865,15 +1894,24 @@ class Router:
|
|||
|
||||
async def a_add_message(
|
||||
self,
|
||||
custom_llm_provider: Literal["openai"],
|
||||
thread_id: str,
|
||||
role: Literal["user", "assistant"],
|
||||
content: str,
|
||||
attachments: Optional[List[Attachment]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
|
||||
client: Optional[AsyncOpenAI] = None,
|
||||
**kwargs,
|
||||
) -> OpenAIMessage:
|
||||
if custom_llm_provider is None:
|
||||
if self.assistants_config is not None:
|
||||
custom_llm_provider = self.assistants_config["custom_llm_provider"]
|
||||
kwargs.update(self.assistants_config["litellm_params"])
|
||||
else:
|
||||
raise Exception(
|
||||
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
|
||||
)
|
||||
|
||||
return await litellm.a_add_message(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
thread_id=thread_id,
|
||||
|
@ -1887,11 +1925,19 @@ class Router:
|
|||
|
||||
async def aget_messages(
|
||||
self,
|
||||
custom_llm_provider: Literal["openai"],
|
||||
thread_id: str,
|
||||
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
|
||||
client: Optional[AsyncOpenAI] = None,
|
||||
**kwargs,
|
||||
) -> AsyncCursorPage[OpenAIMessage]:
|
||||
if custom_llm_provider is None:
|
||||
if self.assistants_config is not None:
|
||||
custom_llm_provider = self.assistants_config["custom_llm_provider"]
|
||||
kwargs.update(self.assistants_config["litellm_params"])
|
||||
else:
|
||||
raise Exception(
|
||||
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
|
||||
)
|
||||
return await litellm.aget_messages(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
thread_id=thread_id,
|
||||
|
@ -1901,9 +1947,9 @@ class Router:
|
|||
|
||||
async def arun_thread(
|
||||
self,
|
||||
custom_llm_provider: Literal["openai"],
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
|
||||
additional_instructions: Optional[str] = None,
|
||||
instructions: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
|
@ -1913,6 +1959,16 @@ class Router:
|
|||
client: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> Run:
|
||||
|
||||
if custom_llm_provider is None:
|
||||
if self.assistants_config is not None:
|
||||
custom_llm_provider = self.assistants_config["custom_llm_provider"]
|
||||
kwargs.update(self.assistants_config["litellm_params"])
|
||||
else:
|
||||
raise Exception(
|
||||
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
|
||||
)
|
||||
|
||||
return await litellm.arun_thread(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
thread_id=thread_id,
|
||||
|
|
|
@ -18,7 +18,10 @@ from litellm.llms.openai import (
|
|||
OpenAIMessage as Message,
|
||||
AsyncCursorPage,
|
||||
SyncCursorPage,
|
||||
AssistantEventHandler,
|
||||
AsyncAssistantEventHandler,
|
||||
)
|
||||
from typing_extensions import override
|
||||
|
||||
"""
|
||||
V0 Scope:
|
||||
|
@ -129,10 +132,26 @@ async def test_add_message_litellm(sync_mode, provider):
|
|||
assert isinstance(added_message, Message)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("provider", ["openai", "azure"])
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"provider",
|
||||
[
|
||||
"azure",
|
||||
"openai",
|
||||
],
|
||||
) #
|
||||
@pytest.mark.parametrize(
|
||||
"sync_mode",
|
||||
[
|
||||
True,
|
||||
False,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"is_streaming",
|
||||
[True, False],
|
||||
) #
|
||||
@pytest.mark.asyncio
|
||||
async def test_aarun_thread_litellm(sync_mode, provider):
|
||||
async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
|
||||
"""
|
||||
- Get Assistants
|
||||
- Create thread
|
||||
|
@ -163,8 +182,16 @@ async def test_aarun_thread_litellm(sync_mode, provider):
|
|||
if sync_mode:
|
||||
added_message = litellm.add_message(**data)
|
||||
|
||||
run = litellm.run_thread(assistant_id=assistant_id, **data)
|
||||
|
||||
if is_streaming:
|
||||
run = litellm.run_thread_stream(assistant_id=assistant_id, **data)
|
||||
with run as run:
|
||||
assert isinstance(run, AssistantEventHandler)
|
||||
print(run)
|
||||
run.until_done()
|
||||
else:
|
||||
run = litellm.run_thread(
|
||||
assistant_id=assistant_id, stream=is_streaming, **data
|
||||
)
|
||||
if run.status == "completed":
|
||||
messages = litellm.get_messages(
|
||||
thread_id=_new_thread.id, custom_llm_provider=provider
|
||||
|
@ -176,8 +203,21 @@ async def test_aarun_thread_litellm(sync_mode, provider):
|
|||
else:
|
||||
added_message = await litellm.a_add_message(**data)
|
||||
|
||||
if is_streaming:
|
||||
run = litellm.arun_thread_stream(assistant_id=assistant_id, **data)
|
||||
async with run as run:
|
||||
print(f"run: {run}")
|
||||
assert isinstance(
|
||||
run,
|
||||
AsyncAssistantEventHandler,
|
||||
)
|
||||
print(run)
|
||||
run.until_done()
|
||||
else:
|
||||
run = await litellm.arun_thread(
|
||||
custom_llm_provider=provider, thread_id=thread_id, assistant_id=assistant_id
|
||||
custom_llm_provider=provider,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
)
|
||||
|
||||
if run.status == "completed":
|
||||
|
|
|
@ -13,6 +13,12 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
|
|||
from openai.types.beta.thread_create_params import (
|
||||
Message as OpenAICreateThreadParamsMessage,
|
||||
)
|
||||
from openai.lib.streaming._assistants import (
|
||||
AssistantEventHandler,
|
||||
AssistantStreamManager,
|
||||
AsyncAssistantStreamManager,
|
||||
AsyncAssistantEventHandler,
|
||||
)
|
||||
from openai.types.beta.assistant_tool_param import AssistantToolParam
|
||||
from openai.types.beta.threads.run import Run
|
||||
from openai.types.beta.assistant import Assistant
|
||||
|
|
|
@ -446,3 +446,8 @@ class ModelGroupInfo(BaseModel):
|
|||
supports_vision: bool = Field(default=False)
|
||||
supports_function_calling: bool = Field(default=False)
|
||||
supported_openai_params: List[str] = Field(default=[])
|
||||
|
||||
|
||||
class AssistantsTypedDict(TypedDict):
|
||||
custom_llm_provider: Literal["azure", "openai"]
|
||||
litellm_params: LiteLLMParamsTypedDict
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue