This commit is contained in:
Krrish Dholakia 2024-06-03 18:47:05 -07:00
parent 93c9ea160d
commit 7163bce37b
8 changed files with 1382 additions and 91 deletions

View file

@ -783,7 +783,11 @@ from .llms.openai import (
MistralConfig, MistralConfig,
DeepInfraConfig, DeepInfraConfig,
) )
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError from .llms.azure import (
AzureOpenAIConfig,
AzureOpenAIError,
AzureOpenAIAssistantsAPIConfig,
)
from .llms.watsonx import IBMWatsonXAIConfig from .llms.watsonx import IBMWatsonXAIConfig
from .main import * # type: ignore from .main import * # type: ignore
from .integrations import * from .integrations import *

View file

@ -4,21 +4,29 @@ from typing import Iterable
from functools import partial from functools import partial
import os, asyncio, contextvars import os, asyncio, contextvars
import litellm import litellm
from openai import OpenAI, AsyncOpenAI from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
from litellm import client from litellm import client
from litellm.utils import supports_httpx_timeout, exception_type, get_llm_provider from litellm.utils import (
supports_httpx_timeout,
exception_type,
get_llm_provider,
get_secret,
)
from ..llms.openai import OpenAIAssistantsAPI from ..llms.openai import OpenAIAssistantsAPI
from ..llms.azure import AzureAssistantsAPI
from ..types.llms.openai import * from ..types.llms.openai import *
from ..types.router import * from ..types.router import *
from .utils import get_optional_params_add_message
####### ENVIRONMENT VARIABLES ################### ####### ENVIRONMENT VARIABLES ###################
openai_assistants_api = OpenAIAssistantsAPI() openai_assistants_api = OpenAIAssistantsAPI()
azure_assistants_api = AzureAssistantsAPI()
### ASSISTANTS ### ### ASSISTANTS ###
async def aget_assistants( async def aget_assistants(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
client: Optional[AsyncOpenAI] = None, client: Optional[AsyncOpenAI] = None,
**kwargs, **kwargs,
) -> AsyncCursorPage[Assistant]: ) -> AsyncCursorPage[Assistant]:
@ -55,12 +63,21 @@ async def aget_assistants(
def get_assistants( def get_assistants(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
client: Optional[OpenAI] = None, client: Optional[Any] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
**kwargs, **kwargs,
) -> SyncCursorPage[Assistant]: ) -> SyncCursorPage[Assistant]:
aget_assistants = kwargs.pop("aget_assistants", None) aget_assistants: Optional[bool] = kwargs.pop("aget_assistants", None)
optional_params = GenericLiteLLMParams(**kwargs) if aget_assistants is not None and not isinstance(aget_assistants, bool):
raise Exception(
"Invalid value passed in for aget_assistants. Only bool or None allowed"
)
optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
### TIMEOUT LOGIC ### ### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -99,6 +116,7 @@ def get_assistants(
or litellm.openai_key or litellm.openai_key
or os.getenv("OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")
) )
response = openai_assistants_api.get_assistants( response = openai_assistants_api.get_assistants(
api_base=api_base, api_base=api_base,
api_key=api_key, api_key=api_key,
@ -106,7 +124,43 @@ def get_assistants(
max_retries=optional_params.max_retries, max_retries=optional_params.max_retries,
organization=organization, organization=organization,
client=client, client=client,
aget_assistants=aget_assistants, aget_assistants=aget_assistants, # type: ignore
) # type: ignore
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_assistants_api.get_assistants(
api_base=api_base,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=optional_params.max_retries,
client=client,
aget_assistants=aget_assistants, # type: ignore
) )
else: else:
raise litellm.exceptions.BadRequestError( raise litellm.exceptions.BadRequestError(
@ -127,7 +181,9 @@ def get_assistants(
### THREADS ### ### THREADS ###
async def acreate_thread(custom_llm_provider: Literal["openai"], **kwargs) -> Thread: async def acreate_thread(
custom_llm_provider: Literal["openai", "azure"], **kwargs
) -> Thread:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ### ### PASS ARGS TO GET ASSISTANTS ###
kwargs["acreate_thread"] = True kwargs["acreate_thread"] = True
@ -161,7 +217,7 @@ async def acreate_thread(custom_llm_provider: Literal["openai"], **kwargs) -> Th
def create_thread( def create_thread(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None, messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None, tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None,
@ -241,6 +297,47 @@ def create_thread(
client=client, client=client,
acreate_thread=acreate_thread, acreate_thread=acreate_thread,
) )
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
if isinstance(client, OpenAI):
client = None # only pass client if it's AzureOpenAI
response = azure_assistants_api.create_thread(
messages=messages,
metadata=metadata,
api_base=api_base,
api_key=api_key,
azure_ad_token=azure_ad_token,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
client=client,
acreate_thread=acreate_thread,
) # type :ignore
else: else:
raise litellm.exceptions.BadRequestError( raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format( message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
@ -254,11 +351,11 @@ def create_thread(
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
), ),
) )
return response return response # type: ignore
async def aget_thread( async def aget_thread(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
thread_id: str, thread_id: str,
client: Optional[AsyncOpenAI] = None, client: Optional[AsyncOpenAI] = None,
**kwargs, **kwargs,
@ -296,9 +393,9 @@ async def aget_thread(
def get_thread( def get_thread(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
thread_id: str, thread_id: str,
client: Optional[OpenAI] = None, client=None,
**kwargs, **kwargs,
) -> Thread: ) -> Thread:
"""Get the thread object, given a thread_id""" """Get the thread object, given a thread_id"""
@ -342,6 +439,7 @@ def get_thread(
or litellm.openai_key or litellm.openai_key
or os.getenv("OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")
) )
response = openai_assistants_api.get_thread( response = openai_assistants_api.get_thread(
thread_id=thread_id, thread_id=thread_id,
api_base=api_base, api_base=api_base,
@ -352,6 +450,46 @@ def get_thread(
client=client, client=client,
aget_thread=aget_thread, aget_thread=aget_thread,
) )
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
if isinstance(client, OpenAI):
client = None # only pass client if it's AzureOpenAI
response = azure_assistants_api.get_thread(
thread_id=thread_id,
api_base=api_base,
api_key=api_key,
azure_ad_token=azure_ad_token,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
client=client,
aget_thread=aget_thread,
)
else: else:
raise litellm.exceptions.BadRequestError( raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format( message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format(
@ -365,20 +503,20 @@ def get_thread(
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
), ),
) )
return response return response # type: ignore
### MESSAGES ### ### MESSAGES ###
async def a_add_message( async def a_add_message(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
thread_id: str, thread_id: str,
role: Literal["user", "assistant"], role: Literal["user", "assistant"],
content: str, content: str,
attachments: Optional[List[Attachment]] = None, attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
client: Optional[AsyncOpenAI] = None, client=None,
**kwargs, **kwargs,
) -> OpenAIMessage: ) -> OpenAIMessage:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -425,22 +563,30 @@ async def a_add_message(
def add_message( def add_message(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
thread_id: str, thread_id: str,
role: Literal["user", "assistant"], role: Literal["user", "assistant"],
content: str, content: str,
attachments: Optional[List[Attachment]] = None, attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
client: Optional[OpenAI] = None, client=None,
**kwargs, **kwargs,
) -> OpenAIMessage: ) -> OpenAIMessage:
### COMMON OBJECTS ### ### COMMON OBJECTS ###
a_add_message = kwargs.pop("a_add_message", None) a_add_message = kwargs.pop("a_add_message", None)
message_data = MessageData( _message_data = MessageData(
role=role, content=content, attachments=attachments, metadata=metadata role=role, content=content, attachments=attachments, metadata=metadata
) )
optional_params = GenericLiteLLMParams(**kwargs) optional_params = GenericLiteLLMParams(**kwargs)
message_data = get_optional_params_add_message(
role=_message_data["role"],
content=_message_data["content"],
attachments=_message_data["attachments"],
metadata=_message_data["metadata"],
custom_llm_provider=custom_llm_provider,
)
### TIMEOUT LOGIC ### ### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default # set timeout for 10 minutes by default
@ -489,6 +635,44 @@ def add_message(
client=client, client=client,
a_add_message=a_add_message, a_add_message=a_add_message,
) )
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_assistants_api.add_message(
thread_id=thread_id,
message_data=message_data,
api_base=api_base,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=optional_params.max_retries,
client=client,
a_add_message=a_add_message,
)
else: else:
raise litellm.exceptions.BadRequestError( raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format( message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
@ -503,11 +687,11 @@ def add_message(
), ),
) )
return response return response # type: ignore
async def aget_messages( async def aget_messages(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
thread_id: str, thread_id: str,
client: Optional[AsyncOpenAI] = None, client: Optional[AsyncOpenAI] = None,
**kwargs, **kwargs,
@ -552,9 +736,9 @@ async def aget_messages(
def get_messages( def get_messages(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
thread_id: str, thread_id: str,
client: Optional[OpenAI] = None, client: Optional[Any] = None,
**kwargs, **kwargs,
) -> SyncCursorPage[OpenAIMessage]: ) -> SyncCursorPage[OpenAIMessage]:
aget_messages = kwargs.pop("aget_messages", None) aget_messages = kwargs.pop("aget_messages", None)
@ -607,6 +791,43 @@ def get_messages(
client=client, client=client,
aget_messages=aget_messages, aget_messages=aget_messages,
) )
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_assistants_api.get_messages(
thread_id=thread_id,
api_base=api_base,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=optional_params.max_retries,
client=client,
aget_messages=aget_messages,
)
else: else:
raise litellm.exceptions.BadRequestError( raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format( message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format(
@ -621,12 +842,12 @@ def get_messages(
), ),
) )
return response return response # type: ignore
### RUNS ### ### RUNS ###
async def arun_thread( async def arun_thread(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
thread_id: str, thread_id: str,
assistant_id: str, assistant_id: str,
additional_instructions: Optional[str] = None, additional_instructions: Optional[str] = None,
@ -635,7 +856,7 @@ async def arun_thread(
model: Optional[str] = None, model: Optional[str] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None, tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[AsyncOpenAI] = None, client: Optional[Any] = None,
**kwargs, **kwargs,
) -> Run: ) -> Run:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -685,7 +906,7 @@ async def arun_thread(
def run_thread( def run_thread(
custom_llm_provider: Literal["openai"], custom_llm_provider: Literal["openai", "azure"],
thread_id: str, thread_id: str,
assistant_id: str, assistant_id: str,
additional_instructions: Optional[str] = None, additional_instructions: Optional[str] = None,
@ -694,7 +915,7 @@ def run_thread(
model: Optional[str] = None, model: Optional[str] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None, tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[OpenAI] = None, client: Optional[Any] = None,
**kwargs, **kwargs,
) -> Run: ) -> Run:
"""Run a given thread + assistant.""" """Run a given thread + assistant."""
@ -755,6 +976,50 @@ def run_thread(
client=client, client=client,
arun_thread=arun_thread, arun_thread=arun_thread,
) )
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_assistants_api.run_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
api_base=str(api_base),
api_key=str(api_key),
api_version=str(api_version),
azure_ad_token=str(azure_ad_token),
timeout=timeout,
max_retries=optional_params.max_retries,
client=client,
arun_thread=arun_thread,
) # type: ignore
else: else:
raise litellm.exceptions.BadRequestError( raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format( message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format(
@ -768,4 +1033,4 @@ def run_thread(
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
), ),
) )
return response return response # type: ignore

158
litellm/assistants/utils.py Normal file
View file

@ -0,0 +1,158 @@
import litellm
from typing import Optional, Union
from ..types.llms.openai import *
def get_optional_params_add_message(
    role: Optional[str],
    content: Optional[
        Union[
            str,
            List[
                Union[
                    MessageContentTextObject,
                    MessageContentImageFileObject,
                    MessageContentImageURLObject,
                ]
            ],
        ]
    ],
    attachments: Optional[List[Attachment]],
    metadata: Optional[dict],
    custom_llm_provider: str,
    **kwargs,
):
    """
    Build the provider-specific kwargs for a "create message" call.

    Azure doesn't support 'attachments' for creating a message, so OpenAI-style
    params are filtered/remapped per provider before being sent.
    Reference - https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
    """
    # The recognized OpenAI create-message params; None means "not provided".
    default_params = {
        "role": None,
        "content": None,
        "attachments": None,
        "metadata": None,
    }
    # Merge the named params with any extra kwargs the caller passed.
    passed_params = {
        "role": role,
        "content": content,
        "attachments": attachments,
        "metadata": metadata,
    }
    passed_params.update(kwargs)
    # Keep only the known params whose value differs from the default.
    non_default_params = {
        name: value
        for name, value in passed_params.items()
        if name in default_params and value != default_params[name]
    }
    optional_params = {}

    def _drop_or_raise_unsupported(supported_params):
        # Drop unsupported non-default values when litellm.drop_params is set;
        # otherwise raise so the caller sees the incompatibility.
        if non_default_params:
            for name in list(non_default_params.keys()):
                if name in supported_params:
                    continue
                if litellm.drop_params is True:
                    non_default_params.pop(name, None)
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        status_code=500,
                        message="k={}, not supported by {}. Supported params={}. To drop it from the call, set `litellm.drop_params = True`.".format(
                            name, custom_llm_provider, supported_params
                        ),
                    )
        return non_default_params

    if custom_llm_provider == "openai":
        # OpenAI accepts every v2 create-message param as-is.
        optional_params = non_default_params
    elif custom_llm_provider == "azure":
        azure_config = litellm.AzureOpenAIAssistantsAPIConfig()
        _drop_or_raise_unsupported(
            supported_params=azure_config.get_supported_openai_create_message_params()
        )
        optional_params = azure_config.map_openai_params_create_message_params(
            non_default_params=non_default_params, optional_params=optional_params
        )
    # Extra kwargs outside the known set are always passed through untouched.
    for name in passed_params:
        if name not in default_params:
            optional_params[name] = passed_params[name]
    return optional_params
def get_optional_params_image_gen(
    n: Optional[int] = None,
    quality: Optional[str] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    style: Optional[str] = None,
    user: Optional[str] = None,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
):
    """
    Build the provider-specific kwargs for an image-generation call.

    OpenAI/Azure (and OpenAI-compatible providers) accept the params as-is;
    Bedrock and Vertex AI only support a subset, which is validated and
    remapped. Unknown extra kwargs are always passed through untouched.
    """
    # retrieve all parameters passed to the function
    passed_params = locals()
    custom_llm_provider = passed_params.pop("custom_llm_provider")
    special_params = passed_params.pop("kwargs")
    for k, v in special_params.items():
        passed_params[k] = v
    default_params = {
        "n": None,
        "quality": None,
        "response_format": None,
        "size": None,
        "style": None,
        "user": None,
    }
    non_default_params = {
        k: v
        for k, v in passed_params.items()
        if (k in default_params and v != default_params[k])
    }
    optional_params = {}

    ## raise exception if non-default value passed for non-openai/azure embedding calls
    def _check_valid_arg(supported_params):
        # Drops unsupported non-default values when litellm.drop_params is
        # set; otherwise raises so the caller sees the incompatibility.
        if len(non_default_params.keys()) > 0:
            keys = list(non_default_params.keys())
            for k in keys:
                if (
                    litellm.drop_params is True and k not in supported_params
                ):  # drop the unsupported non-default values
                    non_default_params.pop(k, None)
                elif k not in supported_params:
                    # Fix: previously raised a bare `UnsupportedParamsError`,
                    # which is not defined in this module (NameError at
                    # runtime). Use the same exception as the sibling helper
                    # above, and name the offending param in the message.
                    raise litellm.utils.UnsupportedParamsError(
                        status_code=500,
                        message="Setting `{}` is not supported by {}. To drop it from the call, set `litellm.drop_params = True`.".format(
                            k, custom_llm_provider
                        ),
                    )
        return non_default_params

    if (
        custom_llm_provider == "openai"
        or custom_llm_provider == "azure"
        or custom_llm_provider in litellm.openai_compatible_providers
    ):
        optional_params = non_default_params
    elif custom_llm_provider == "bedrock":
        # Bedrock only understands width/height, derived from 'size'.
        supported_params = ["size"]
        _check_valid_arg(supported_params=supported_params)
        if size is not None:
            width, height = size.split("x")
            optional_params["width"] = int(width)
            optional_params["height"] = int(height)
    elif custom_llm_provider == "vertex_ai":
        supported_params = ["n"]
        """
        All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
        """
        _check_valid_arg(supported_params=supported_params)
        if n is not None:
            optional_params["sampleCount"] = int(n)
    for k in passed_params.keys():
        if k not in default_params.keys():
            optional_params[k] = passed_params[k]
    return optional_params

View file

@ -1,4 +1,5 @@
from typing import Optional, Union, Any, Literal from typing import Optional, Union, Any, Literal, Coroutine, Iterable
from typing_extensions import overload
import types, requests import types, requests
from .base import BaseLLM from .base import BaseLLM
from litellm.utils import ( from litellm.utils import (
@ -18,6 +19,18 @@ from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTra
from openai import AzureOpenAI, AsyncAzureOpenAI from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid import uuid
import os import os
from ..types.llms.openai import (
AsyncCursorPage,
AssistantToolParam,
SyncCursorPage,
Assistant,
MessageData,
OpenAIMessage,
OpenAICreateThreadParamsMessage,
Thread,
AssistantToolParam,
Run,
)
class AzureOpenAIError(Exception): class AzureOpenAIError(Exception):
@ -114,6 +127,68 @@ class AzureOpenAIConfig(OpenAIConfig):
return ["europe", "sweden", "switzerland", "france", "uk"] return ["europe", "sweden", "switzerland", "france", "uk"]
class AzureOpenAIAssistantsAPIConfig:
    """
    Maps OpenAI Assistants API create-message params onto Azure equivalents.

    Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
    """

    def __init__(
        self,
    ) -> None:
        pass

    def get_supported_openai_create_message_params(self):
        """Return the OpenAI create-message params Azure understands."""
        return [
            "role",
            "content",
            "attachments",
            "metadata",
        ]

    def map_openai_params_create_message_params(
        self, non_default_params: dict, optional_params: dict
    ):
        """
        Translate OpenAI create-message params into Azure form.

        'attachments' is a v2 param; Azure still uses the legacy 'file_ids'
        list, so any attachment 'file_id' values are collected and remapped.
        Raises litellm.utils.UnsupportedParamsError for values Azure can't
        express (unless `litellm.drop_params` is True).
        """
        for param, value in non_default_params.items():
            if param == "role":
                optional_params["role"] = value
            elif param == "metadata":
                optional_params["metadata"] = value
            elif param == "content":  # only string accepted
                if isinstance(value, str):
                    optional_params["content"] = value
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        message="Azure only accepts content as a string.",
                        status_code=400,
                    )
            elif (
                param == "attachments"
            ):  # this is a v2 param. Azure currently supports the old 'file_id's param
                file_ids: List[str] = []
                if isinstance(value, list):
                    for item in value:
                        if "file_id" in item:
                            file_ids.append(item["file_id"])
                        else:
                            if litellm.drop_params is True:
                                pass
                            else:
                                raise litellm.utils.UnsupportedParamsError(
                                    message="Azure doesn't support {}. To drop it from the call, set `litellm.drop_params = True.".format(
                                        value
                                    ),
                                    status_code=400,
                                )
                    # Fix: the collected ids were previously discarded, so
                    # attachments were silently dropped. Forward them via
                    # Azure's legacy 'file_ids' param.
                    if file_ids:
                        optional_params["file_ids"] = file_ids
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        message="Invalid param. attachments should always be a list. Got={}, Expected=List. Raw value={}".format(
                            type(value), value
                        ),
                        status_code=400,
                    )
        return optional_params
def select_azure_base_url_or_endpoint(azure_client_params: dict): def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = { # azure_client_params = {
# "api_version": api_version, # "api_version": api_version,
@ -172,9 +247,7 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
possible_azure_ad_token = req_token.json().get("access_token", None) possible_azure_ad_token = req_token.json().get("access_token", None)
if possible_azure_ad_token is None: if possible_azure_ad_token is None:
raise AzureOpenAIError( raise AzureOpenAIError(status_code=422, message="Azure AD Token not returned")
status_code=422, message="Azure AD Token not returned"
)
return possible_azure_ad_token return possible_azure_ad_token
@ -245,7 +318,9 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(
azure_ad_token
)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
@ -1192,3 +1267,741 @@ class AzureChatCompletion(BaseLLM):
response["x-ms-region"] = completion.headers["x-ms-region"] response["x-ms-region"] = completion.headers["x-ms-region"]
return response return response
class AzureAssistantsAPI(BaseLLM):
def __init__(self) -> None:
    # Stateless wrapper: all Azure credentials/config are passed per-call,
    # so construction only delegates to the BaseLLM initializer.
    super().__init__()
def get_azure_client(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AzureOpenAI] = None,
) -> AzureOpenAI:
    """Return the caller-supplied sync client, or build one from the given credentials."""
    if client is not None:
        return client
    # Collect constructor kwargs, skipping unset values. The SDK expects
    # the base URL under 'azure_endpoint', not 'api_base'.
    client_kwargs = {}
    if api_base is not None:
        client_kwargs["azure_endpoint"] = api_base
    if api_key is not None:
        client_kwargs["api_key"] = api_key
    if api_version is not None:
        client_kwargs["api_version"] = api_version
    if azure_ad_token is not None:
        client_kwargs["azure_ad_token"] = azure_ad_token
    if timeout is not None:
        client_kwargs["timeout"] = timeout
    if max_retries is not None:
        client_kwargs["max_retries"] = max_retries
    return AzureOpenAI(**client_kwargs)  # type: ignore
def async_get_azure_client(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI] = None,
) -> AsyncAzureOpenAI:
    """Return the caller-supplied async client, or build one from the given credentials."""
    if client is not None:
        return client
    # Collect constructor kwargs, skipping unset values. The SDK expects
    # the base URL under 'azure_endpoint', not 'api_base'.
    client_kwargs = {}
    if api_base is not None:
        client_kwargs["azure_endpoint"] = api_base
    if api_key is not None:
        client_kwargs["api_key"] = api_key
    if api_version is not None:
        client_kwargs["api_version"] = api_version
    if azure_ad_token is not None:
        client_kwargs["azure_ad_token"] = azure_ad_token
    if timeout is not None:
        client_kwargs["timeout"] = timeout
    if max_retries is not None:
        client_kwargs["max_retries"] = max_retries
    return AsyncAzureOpenAI(**client_kwargs)  # type: ignore
### ASSISTANTS ###
async def async_get_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
) -> AsyncCursorPage[Assistant]:
    """List the assistants on the Azure deployment (async)."""
    azure_client = self.async_get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    return await azure_client.beta.assistants.list()
# fmt: off

@overload
def get_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    aget_assistants: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
    ...

@overload
def get_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AzureOpenAI],
    aget_assistants: Optional[Literal[False]],
) -> SyncCursorPage[Assistant]:
    ...

# fmt: on

def get_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client=None,
    aget_assistants=None,
):
    """List assistants; returns a coroutine instead when aget_assistants=True."""
    # Both the sync and async paths take the same connection arguments.
    connection_kwargs = dict(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    if aget_assistants is not None and aget_assistants == True:
        return self.async_get_assistants(**connection_kwargs)
    azure_client = self.get_azure_client(**connection_kwargs)
    return azure_client.beta.assistants.list()
### MESSAGES ###
async def a_add_message(
    self,
    thread_id: str,
    message_data: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI] = None,
) -> OpenAIMessage:
    """Append a message to an existing thread (async)."""
    azure_client = self.async_get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    created_message = await azure_client.beta.threads.messages.create(  # type: ignore
        thread_id, **message_data  # type: ignore
    )
    # Azure may omit 'status'; default it so OpenAIMessage validates.
    if getattr(created_message, "status", None) is None:
        created_message.status = "completed"
    return OpenAIMessage(**created_message.dict())
# fmt: off
@overload
def add_message(
    self,
    thread_id: str,
    message_data: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    a_add_message: Literal[True],
) -> Coroutine[None, None, OpenAIMessage]:
    # Typing overload only: a_add_message=True selects the async path,
    # so callers receive an awaitable OpenAIMessage.
    ...
@overload
def add_message(
    self,
    thread_id: str,
    message_data: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AzureOpenAI],
    a_add_message: Optional[Literal[False]],
) -> OpenAIMessage:
    # Typing overload only: a_add_message falsy/None selects the sync path,
    # returning the created OpenAIMessage directly.
    ...
# fmt: on
def add_message(
    self,
    thread_id: str,
    message_data: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client=None,
    a_add_message: Optional[bool] = None,
):
    """Append a message to an existing thread.

    Dispatches to :meth:`a_add_message` (returning a coroutine) when
    ``a_add_message`` is True; otherwise runs synchronously and returns
    the created ``OpenAIMessage``.
    """
    if a_add_message is not None and a_add_message == True:
        return self.a_add_message(
            thread_id=thread_id,
            message_data=message_data,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )
    openai_client = self.get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(  # type: ignore
        thread_id, **message_data  # type: ignore
    )
    # Azure may omit `status`; default it so the OpenAIMessage model
    # validates. (Both prior branches built the identical object, so the
    # duplicated construction has been collapsed.)
    if getattr(thread_message, "status", None) is None:
        thread_message.status = "completed"  # type: ignore
    return OpenAIMessage(**thread_message.dict())
async def async_get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI] = None,
) -> AsyncCursorPage[OpenAIMessage]:
    """Fetch the messages of a thread (async); returns an async cursor page."""
    azure_client = self.async_get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    return await azure_client.beta.threads.messages.list(thread_id=thread_id)
# fmt: off
@overload
def get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    aget_messages: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
    # Typing overload only: aget_messages=True selects the async path.
    ...
@overload
def get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AzureOpenAI],
    aget_messages: Optional[Literal[False]],
) -> SyncCursorPage[OpenAIMessage]:
    # Typing overload only: aget_messages falsy/None selects the sync path.
    ...
# fmt: on
def get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client=None,
    aget_messages=None,
):
    """Fetch the messages of a thread.

    With ``aget_messages=True``, returns the coroutine from the async
    implementation; otherwise lists the messages synchronously.
    """
    # None == True is False, so the explicit None-guard is redundant.
    if aget_messages == True:  # noqa: E712
        return self.async_get_messages(
            thread_id=thread_id,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )
    azure_client = self.get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    return azure_client.beta.threads.messages.list(thread_id=thread_id)
### THREADS ###
async def async_create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
    """Create a new thread (async), optionally seeded with messages/metadata."""
    azure_client = self.async_get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    # Only include keys the caller actually supplied.
    payload = {
        key: value
        for key, value in (("messages", messages), ("metadata", metadata))
        if value is not None
    }
    message_thread = await azure_client.beta.threads.create(**payload)  # type: ignore
    return Thread(**message_thread.dict())
# fmt: off
@overload
def create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
    client: Optional[AsyncAzureOpenAI],
    acreate_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
    # Typing overload only: acreate_thread=True selects the async path.
    ...
@overload
def create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
    client: Optional[AzureOpenAI],
    acreate_thread: Optional[Literal[False]],
) -> Thread:
    # Typing overload only: acreate_thread falsy/None selects the sync path.
    ...
# fmt: on
def create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
    client=None,
    acreate_thread=None,
):
    """
    Create a new thread, optionally seeded with messages and metadata.

    Here's an example:
    ```
    # create thread
    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
    azure_api.create_thread(messages=[message])
    ```
    """
    # Async dispatch: None == True is False, so no extra None-guard needed.
    if acreate_thread == True:  # noqa: E712
        return self.async_create_thread(
            metadata=metadata,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            messages=messages,
        )
    azure_client = self.get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    # Only include keys the caller actually supplied.
    payload = {
        key: value
        for key, value in (("messages", messages), ("metadata", metadata))
        if value is not None
    }
    message_thread = azure_client.beta.threads.create(**payload)  # type: ignore
    return Thread(**message_thread.dict())
async def async_get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
) -> Thread:
    """Retrieve a thread by id (async)."""
    azure_client = self.async_get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    raw = await azure_client.beta.threads.retrieve(thread_id=thread_id)
    return Thread(**raw.dict())
# fmt: off
@overload
def get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    aget_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
    # Typing overload only: aget_thread=True selects the async path.
    ...
@overload
def get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AzureOpenAI],
    aget_thread: Optional[Literal[False]],
) -> Thread:
    # Typing overload only: aget_thread falsy/None selects the sync path.
    ...
# fmt: on
def get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client=None,
    aget_thread=None,
):
    """Retrieve a thread by id; async when ``aget_thread=True``."""
    # None == True is False, so the explicit None-guard is redundant.
    if aget_thread == True:  # noqa: E712
        return self.async_get_thread(
            thread_id=thread_id,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )
    azure_client = self.get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    raw = azure_client.beta.threads.retrieve(thread_id=thread_id)
    return Thread(**raw.dict())
# def delete_thread(self):
# pass
### RUNS ###
async def arun_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
) -> Run:
    """Create a run on a thread and poll it to a terminal state (async)."""
    azure_client = self.async_get_azure_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        client=client,
    )
    # NOTE(review): `stream` is accepted but never forwarded --
    # create_and_poll is inherently non-streaming; confirm whether a
    # streaming path should exist for this parameter.
    return await azure_client.beta.threads.runs.create_and_poll(  # type: ignore
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )
# fmt: off
@overload
def run_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    arun_thread: Literal[True],
) -> Coroutine[None, None, Run]:
    # Typing overload only: arun_thread=True selects the async path.
    ...
@overload
def run_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AzureOpenAI],
    arun_thread: Optional[Literal[False]],
) -> Run:
    # Typing overload only: arun_thread falsy/None selects the sync path.
    ...
# fmt: on
def run_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client=None,
    arun_thread=None,
):
    """Create a run on a thread and poll it to a terminal state.

    Dispatches to :meth:`arun_thread` (returning a coroutine) when
    ``arun_thread`` is True; otherwise runs synchronously.
    """
    # None == True is False, so the explicit None-guard is redundant.
    if arun_thread == True:  # noqa: E712
        return self.arun_thread(
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            stream=stream,
            tools=tools,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )
    azure_client = self.get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    # NOTE(review): `stream` is accepted but never forwarded --
    # create_and_poll is inherently non-streaming; confirm whether a
    # streaming path should exist for this parameter.
    return azure_client.beta.threads.runs.create_and_poll(  # type: ignore
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )

View file

@ -2088,7 +2088,7 @@ class OpenAIAssistantsAPI(BaseLLM):
async def a_add_message( async def a_add_message(
self, self,
thread_id: str, thread_id: str,
message_data: MessageData, message_data: dict,
api_key: Optional[str], api_key: Optional[str],
api_base: Optional[str], api_base: Optional[str],
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
@ -2123,7 +2123,7 @@ class OpenAIAssistantsAPI(BaseLLM):
def add_message( def add_message(
self, self,
thread_id: str, thread_id: str,
message_data: MessageData, message_data: dict,
api_key: Optional[str], api_key: Optional[str],
api_base: Optional[str], api_base: Optional[str],
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
@ -2138,7 +2138,7 @@ class OpenAIAssistantsAPI(BaseLLM):
def add_message( def add_message(
self, self,
thread_id: str, thread_id: str,
message_data: MessageData, message_data: dict,
api_key: Optional[str], api_key: Optional[str],
api_base: Optional[str], api_base: Optional[str],
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
@ -2154,7 +2154,7 @@ class OpenAIAssistantsAPI(BaseLLM):
def add_message( def add_message(
self, self,
thread_id: str, thread_id: str,
message_data: MessageData, message_data: dict,
api_key: Optional[str], api_key: Optional[str],
api_base: Optional[str], api_base: Optional[str],
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
@ -2552,7 +2552,7 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
max_retries: Optional[int], max_retries: Optional[int],
organization: Optional[str], organization: Optional[str],
client: Optional[AsyncOpenAI], client,
arun_thread: Literal[True], arun_thread: Literal[True],
) -> Coroutine[None, None, Run]: ) -> Coroutine[None, None, Run]:
... ...
@ -2573,7 +2573,7 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
max_retries: Optional[int], max_retries: Optional[int],
organization: Optional[str], organization: Optional[str],
client: Optional[OpenAI], client,
arun_thread: Optional[Literal[False]], arun_thread: Optional[Literal[False]],
) -> Run: ) -> Run:
... ...

View file

@ -1905,7 +1905,7 @@ class Router:
model: Optional[str] = None, model: Optional[str] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None, tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[AsyncOpenAI] = None, client: Optional[Any] = None,
**kwargs, **kwargs,
) -> Run: ) -> Run:
return await litellm.arun_thread( return await litellm.arun_thread(

View file

@ -17,6 +17,7 @@ from litellm.llms.openai import (
Thread, Thread,
OpenAIMessage as Message, OpenAIMessage as Message,
AsyncCursorPage, AsyncCursorPage,
SyncCursorPage,
) )
""" """
@ -27,27 +28,43 @@ V0 Scope:
""" """
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize(
"sync_mode",
[True, False],
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_get_assistants(): async def test_get_assistants(provider, sync_mode):
assistants = await litellm.aget_assistants(custom_llm_provider="openai") data = {
assert isinstance(assistants, AsyncCursorPage) "custom_llm_provider": provider,
}
if provider == "azure":
data["api_version"] = "2024-02-15-preview"
if sync_mode == True:
assistants = litellm.get_assistants(**data)
assert isinstance(assistants, SyncCursorPage)
else:
assistants = await litellm.aget_assistants(**data)
assert isinstance(assistants, AsyncCursorPage)
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_thread_litellm(sync_mode) -> Thread: async def test_create_thread_litellm(sync_mode, provider) -> Thread:
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
data = {
"custom_llm_provider": provider,
"message": [message],
}
if provider == "azure":
data["api_version"] = "2024-02-15-preview"
if sync_mode: if sync_mode:
new_thread = create_thread( new_thread = create_thread(**data)
custom_llm_provider="openai",
messages=[message], # type: ignore
)
else: else:
new_thread = await litellm.acreate_thread( new_thread = await litellm.acreate_thread(**data)
custom_llm_provider="openai",
messages=[message], # type: ignore
)
assert isinstance( assert isinstance(
new_thread, Thread new_thread, Thread
@ -56,26 +73,28 @@ async def test_create_thread_litellm(sync_mode) -> Thread:
return new_thread return new_thread
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_thread_litellm(sync_mode): async def test_get_thread_litellm(provider, sync_mode):
new_thread = test_create_thread_litellm(sync_mode) new_thread = test_create_thread_litellm(sync_mode, provider)
if asyncio.iscoroutine(new_thread): if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread _new_thread = await new_thread
else: else:
_new_thread = new_thread _new_thread = new_thread
data = {
"custom_llm_provider": provider,
"thread_id": _new_thread.id,
}
if provider == "azure":
data["api_version"] = "2024-02-15-preview"
if sync_mode: if sync_mode:
received_thread = get_thread( received_thread = get_thread(**data)
custom_llm_provider="openai",
thread_id=_new_thread.id,
)
else: else:
received_thread = await litellm.aget_thread( received_thread = await litellm.aget_thread(**data)
custom_llm_provider="openai",
thread_id=_new_thread.id,
)
assert isinstance( assert isinstance(
received_thread, Thread received_thread, Thread
@ -83,11 +102,12 @@ async def test_get_thread_litellm(sync_mode):
return new_thread return new_thread
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_message_litellm(sync_mode): async def test_add_message_litellm(sync_mode, provider):
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
new_thread = test_create_thread_litellm(sync_mode) new_thread = test_create_thread_litellm(sync_mode, provider)
if asyncio.iscoroutine(new_thread): if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread _new_thread = await new_thread
@ -95,37 +115,38 @@ async def test_add_message_litellm(sync_mode):
_new_thread = new_thread _new_thread = new_thread
# add message to thread # add message to thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message}
if provider == "azure":
data["api_version"] = "2024-02-15-preview"
if sync_mode: if sync_mode:
added_message = litellm.add_message( added_message = litellm.add_message(**data)
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
else: else:
added_message = await litellm.a_add_message( added_message = await litellm.a_add_message(**data)
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
print(f"added message: {added_message}") print(f"added message: {added_message}")
assert isinstance(added_message, Message) assert isinstance(added_message, Message)
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_thread_litellm(sync_mode): async def test_run_thread_litellm(sync_mode, provider):
""" """
- Get Assistants - Get Assistants
- Create thread - Create thread
- Create run w/ Assistants + Thread - Create run w/ Assistants + Thread
""" """
if sync_mode: if sync_mode:
assistants = litellm.get_assistants(custom_llm_provider="openai") assistants = litellm.get_assistants(custom_llm_provider=provider)
else: else:
assistants = await litellm.aget_assistants(custom_llm_provider="openai") assistants = await litellm.aget_assistants(custom_llm_provider=provider)
## get the first assistant ### ## get the first assistant ###
assistant_id = assistants.data[0].id assistant_id = assistants.data[0].id
new_thread = test_create_thread_litellm(sync_mode=sync_mode) new_thread = test_create_thread_litellm(sync_mode=sync_mode, provider=provider)
if asyncio.iscoroutine(new_thread): if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread _new_thread = await new_thread
@ -137,35 +158,31 @@ async def test_run_thread_litellm(sync_mode):
# add message to thread # add message to thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
if sync_mode: data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message}
added_message = litellm.add_message(
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
run = litellm.run_thread( if sync_mode:
custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id added_message = litellm.add_message(**data)
)
run = litellm.run_thread(assistant_id=assistant_id, **data)
if run.status == "completed": if run.status == "completed":
messages = litellm.get_messages( messages = litellm.get_messages(
thread_id=_new_thread.id, custom_llm_provider="openai" thread_id=_new_thread.id, custom_llm_provider=provider
) )
assert isinstance(messages.data[0], Message) assert isinstance(messages.data[0], Message)
else: else:
pytest.fail("An unexpected error occurred when running the thread") pytest.fail("An unexpected error occurred when running the thread")
else: else:
added_message = await litellm.a_add_message( added_message = await litellm.a_add_message(**data)
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
run = await litellm.arun_thread( run = await litellm.arun_thread(
custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id custom_llm_provider=provider, thread_id=thread_id, assistant_id=assistant_id
) )
if run.status == "completed": if run.status == "completed":
messages = await litellm.aget_messages( messages = await litellm.aget_messages(
thread_id=_new_thread.id, custom_llm_provider="openai" thread_id=_new_thread.id, custom_llm_provider=provider
) )
assert isinstance(messages.data[0], Message) assert isinstance(messages.data[0], Message)
else: else:

View file

@ -136,9 +136,43 @@ class Attachment(TypedDict, total=False):
"""The tools to add this file to.""" """The tools to add this file to."""
class ImageFileObject(TypedDict):
file_id: Required[str]
detail: Optional[str]
class ImageURLObject(TypedDict):
url: Required[str]
detail: Optional[str]
class MessageContentTextObject(TypedDict):
type: Required[Literal["text"]]
text: str
class MessageContentImageFileObject(TypedDict):
type: Literal["image_file"]
image_file: ImageFileObject
class MessageContentImageURLObject(TypedDict):
type: Required[str]
image_url: ImageURLObject
class MessageData(TypedDict): class MessageData(TypedDict):
role: Literal["user", "assistant"] role: Literal["user", "assistant"]
content: str content: Union[
str,
List[
Union[
MessageContentTextObject,
MessageContentImageFileObject,
MessageContentImageURLObject,
]
],
]
attachments: Optional[List[Attachment]] attachments: Optional[List[Attachment]]
metadata: Optional[dict] metadata: Optional[dict]