add create_assistants

This commit is contained in:
Ishaan Jaff 2024-07-09 08:51:42 -07:00
parent f6a0d6f33d
commit 9e22ce905e
2 changed files with 156 additions and 5 deletions

View file

@@ -1,19 +1,27 @@
# What is this? # What is this?
## Main file for assistants API logic ## Main file for assistants API logic
from typing import Iterable import asyncio
import contextvars
import os
from functools import partial from functools import partial
import os, asyncio, contextvars from typing import Any, Dict, Iterable, List, Literal, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from openai.types.beta.assistant import Assistant
import litellm import litellm
from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
from litellm import client from litellm import client
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import ( from litellm.utils import (
supports_httpx_timeout,
exception_type, exception_type,
get_llm_provider, get_llm_provider,
get_secret, get_secret,
supports_httpx_timeout,
) )
from ..llms.openai import OpenAIAssistantsAPI
from ..llms.azure import AzureAssistantsAPI from ..llms.azure import AzureAssistantsAPI
from ..llms.openai import OpenAIAssistantsAPI
from ..types.llms.openai import * from ..types.llms.openai import *
from ..types.router import * from ..types.router import *
from .utils import get_optional_params_add_message from .utils import get_optional_params_add_message
@@ -178,6 +186,116 @@ def get_assistants(
return response return response
def create_assistants(
    custom_llm_provider: Literal["openai", "azure"],
    model: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    instructions: Optional[str] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_resources: Optional[Dict[str, Any]] = None,
    metadata: Optional[Dict[str, str]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    response_format: Optional[Union[str, Dict[str, str]]] = None,
    client: Optional[Any] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    **kwargs,
) -> Assistant:
    """Create an assistant on the given provider.

    Only ``custom_llm_provider == "openai"`` is supported here; any other
    value raises ``litellm.exceptions.BadRequestError``. Assistant fields
    (``name``, ``instructions``, ``tools``, ...) are forwarded verbatim to the
    OpenAI Assistants "create" endpoint via ``openai_assistants_api``.

    Kwargs of note:
        acreate_assistants (Optional[bool]): when True, the provider-level
            helper is told to run asynchronously. Must be bool or None.
        request_timeout: fallback timeout (seconds) if none is set in
            ``GenericLiteLLMParams``; defaults to 600 (10 minutes).

    Raises:
        ValueError: if ``acreate_assistants`` is neither bool nor None.
        litellm.exceptions.BadRequestError: unsupported provider.
        litellm.exceptions.InternalServerError: provider returned no response.
    """
    acreate_assistants: Optional[bool] = kwargs.pop("acreate_assistants", None)
    if acreate_assistants is not None and not isinstance(acreate_assistants, bool):
        raise ValueError(
            "Invalid value passed in for acreate_assistants. Only bool or None allowed"
        )
    optional_params = GenericLiteLLMParams(
        api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
    )

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
    # set timeout for 10 minutes by default; providers that can't accept an
    # httpx.Timeout object get a plain float (the read timeout) instead.
    # (isinstance already excludes None, so no separate None guard is needed.)
    if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout(
        custom_llm_provider
    ):
        timeout = timeout.read or 600  # default 10 min timeout
    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
        timeout = float(timeout)  # type: ignore
    elif timeout is None:
        timeout = 600.0

    response: Optional[Assistant] = None
    if custom_llm_provider == "openai":
        api_base = (
            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or os.getenv("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        organization = (
            optional_params.organization
            or litellm.organization
            or os.getenv("OPENAI_ORGANIZATION", None)
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )
        # set API KEY
        api_key = (
            optional_params.api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or os.getenv("OPENAI_API_KEY")
        )

        # Payload forwarded as-is to the OpenAI assistants.create endpoint.
        create_assistant_data = {
            "model": model,
            "name": name,
            "description": description,
            "instructions": instructions,
            "tools": tools,
            "tool_resources": tool_resources,
            "metadata": metadata,
            "temperature": temperature,
            "top_p": top_p,
            "response_format": response_format,
        }

        response = openai_assistants_api.create_assistants(
            api_base=api_base,
            api_key=api_key,
            timeout=timeout,
            max_retries=optional_params.max_retries,
            organization=organization,
            create_assistant_data=create_assistant_data,
            client=client,
            acreate_assistants=acreate_assistants,  # type: ignore
        )  # type: ignore
    else:
        raise litellm.exceptions.BadRequestError(
            message="LiteLLM doesn't support {} for 'create_assistants'. Only 'openai' is supported.".format(
                custom_llm_provider
            ),
            model="n/a",
            llm_provider=custom_llm_provider,
            response=httpx.Response(
                status_code=400,
                content="Unsupported provider",
                # fixed: this synthetic request was mislabeled "create_thread"
                # (copied from the threads helper); label it for this endpoint.
                request=httpx.Request(method="create_assistants", url="https://github.com/BerriAI/litellm"),  # type: ignore
            ),
        )
    if response is None:
        raise litellm.exceptions.InternalServerError(
            message="No response returned from 'create_assistants'",
            model=model,
            llm_provider=custom_llm_provider,
        )
    return response
### THREADS ### ### THREADS ###

View file

@@ -2383,6 +2383,39 @@ class OpenAIAssistantsAPI(BaseLLM):
return response return response
# Create Assistant
def create_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    create_assistant_data: dict,
    client=None,
    acreate_assistants=None,
):
    """Create an assistant via the OpenAI SDK's ``beta.assistants.create``.

    ``create_assistant_data`` is expanded as keyword arguments directly into
    the SDK call, so its keys must match the OpenAI create-assistant params
    (model, name, instructions, tools, ...).

    When ``acreate_assistants`` is truthy the call is dispatched to the
    async helper instead of using the synchronous client built below.
    """
    if acreate_assistants is not None and acreate_assistants == True:
        # NOTE(review): this async branch delegates to async_get_assistants
        # and drops create_assistant_data entirely — it would LIST assistants
        # rather than create one. Looks like a copy/paste bug; presumably an
        # async *create* helper should be called here. Confirm and fix.
        return self.async_get_assistants(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
    # Reuses a caller-supplied client when given; otherwise builds one from
    # the provided credentials/timeout/retry settings.
    openai_client = self.get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )

    response = openai_client.beta.assistants.create(**create_assistant_data)

    return response
### MESSAGES ### ### MESSAGES ###
async def a_add_message( async def a_add_message(