add acreate_assistants

This commit is contained in:
Ishaan Jaff 2024-07-09 09:33:41 -07:00
parent 06926920d5
commit f4f07e13f3
2 changed files with 74 additions and 7 deletions

View file

@ -186,6 +186,45 @@ def get_assistants(
return response return response
async def acreate_assistants(
custom_llm_provider: Literal["openai", "azure"],
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Assistant:
loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ###
kwargs["async_create_assistants"] = True
try:
model = kwargs.pop("model", None)
kwargs["client"] = client
# Use a partial function to pass your keyword arguments
func = partial(create_assistants, custom_llm_provider, model, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model=model, custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def create_assistants( def create_assistants(
custom_llm_provider: Literal["openai", "azure"], custom_llm_provider: Literal["openai", "azure"],
model: str, model: str,
@ -204,10 +243,14 @@ def create_assistants(
api_version: Optional[str] = None, api_version: Optional[str] = None,
**kwargs, **kwargs,
) -> Assistant: ) -> Assistant:
acreate_assistants: Optional[bool] = kwargs.pop("acreate_assistants", None) async_create_assistants: Optional[bool] = kwargs.pop(
if acreate_assistants is not None and not isinstance(acreate_assistants, bool): "async_create_assistants", None
)
if async_create_assistants is not None and not isinstance(
async_create_assistants, bool
):
raise ValueError( raise ValueError(
"Invalid value passed in for acreate_assistants. Only bool or None allowed" "Invalid value passed in for async_create_assistants. Only bool or None allowed"
) )
optional_params = GenericLiteLLMParams( optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
@ -272,7 +315,7 @@ def create_assistants(
organization=organization, organization=organization,
create_assistant_data=create_assistant_data, create_assistant_data=create_assistant_data,
client=client, client=client,
acreate_assistants=acreate_assistants, # type: ignore async_create_assistants=async_create_assistants, # type: ignore
) # type: ignore ) # type: ignore
else: else:
raise litellm.exceptions.BadRequestError( raise litellm.exceptions.BadRequestError(

View file

@ -2384,6 +2384,29 @@ class OpenAIAssistantsAPI(BaseLLM):
return response return response
# Create Assistant # Create Assistant
async def async_create_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
create_assistant_data: dict,
) -> Assistant:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.create(**create_assistant_data)
return response
def create_assistants( def create_assistants(
self, self,
api_key: Optional[str], api_key: Optional[str],
@ -2393,16 +2416,17 @@ class OpenAIAssistantsAPI(BaseLLM):
organization: Optional[str], organization: Optional[str],
create_assistant_data: dict, create_assistant_data: dict,
client=None, client=None,
acreate_assistants=None, async_create_assistants=None,
): ):
if acreate_assistants is not None and acreate_assistants == True: if async_create_assistants is not None and async_create_assistants == True:
return self.async_get_assistants( return self.async_create_assistants(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
organization=organization, organization=organization,
client=client, client=client,
create_assistant_data=create_assistant_data,
) )
openai_client = self.get_openai_client( openai_client = self.get_openai_client(
api_key=api_key, api_key=api_key,