diff --git a/litellm/__init__.py b/litellm/__init__.py
index dc0959dcf..5e745668b 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -605,7 +605,6 @@ all_embedding_models = (
 
 ####### IMAGE GENERATION MODELS ###################
 openai_image_generation_models = ["dall-e-2", "dall-e-3"]
-
 from .timeout import timeout
 from .utils import (
     client,
@@ -695,3 +694,4 @@ from .exceptions import (
 from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server
 from .router import Router
+from .assistants.main import *
diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py
new file mode 100644
index 000000000..25d2433d7
--- /dev/null
+++ b/litellm/assistants/main.py
@@ -0,0 +1,495 @@
+# What is this?
+## Main file for assistants API logic
+from typing import Iterable
+import os
+import litellm
+from openai import OpenAI
+from litellm import client
+from litellm.utils import supports_httpx_timeout
+from ..llms.openai import OpenAIAssistantsAPI
+from ..types.llms.openai import *
+from ..types.router import *
+
+####### ENVIRONMENT VARIABLES ###################
+openai_assistants_api = OpenAIAssistantsAPI()
+
+### ASSISTANTS ###
+
+
+def get_assistants(
+    custom_llm_provider: Literal["openai"],
+    client: Optional[OpenAI] = None,
+    **kwargs,
+) -> SyncCursorPage[Assistant]:
+    optional_params = GenericLiteLLMParams(**kwargs)
+
+    ### TIMEOUT LOGIC ###
+    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+    elif timeout is None:
+        timeout = 600.0
+
+    response: Optional[SyncCursorPage[Assistant]] = None
+    if custom_llm_provider == "openai":
+        api_base = (
+            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            or litellm.api_base
+            or os.getenv("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        organization = (
+            optional_params.organization
+            or litellm.organization
+            or os.getenv("OPENAI_ORGANIZATION", None)
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )
+        # set API KEY
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or os.getenv("OPENAI_API_KEY")
+        )
+        response = openai_assistants_api.get_assistants(
+            api_base=api_base,
+            api_key=api_key,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            organization=organization,
+            client=client,
+        )
+    else:
+        raise litellm.exceptions.BadRequestError(
+            message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' is supported.".format(
+                custom_llm_provider
+            ),
+            model="n/a",
+            llm_provider=custom_llm_provider,
+            response=httpx.Response(
+                status_code=400,
+                content="Unsupported provider",
+                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
+            ),
+        )
+    return response
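Reviewer note: the star-export from `litellm/__init__.py` makes this callable as `litellm.get_assistants`. A minimal usage sketch, assuming `OPENAI_API_KEY` is set in the environment and the account has at least one assistant:

```python
import litellm

# list all assistants on the account (returned as a paginated SyncCursorPage)
assistants = litellm.get_assistants(custom_llm_provider="openai")
for assistant in assistants.data:
    print(assistant.id, assistant.name)
```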
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + + +### THREADS ### + + +def create_thread( + custom_llm_provider: Literal["openai"], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None, + metadata: Optional[dict] = None, + tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None, + client: Optional[OpenAI] = None, + **kwargs, +) -> Thread: + """ + - get the llm provider + - if openai - route it there + - pass through relevant params + + ``` + from litellm import create_thread + + create_thread( + custom_llm_provider="openai", + ### OPTIONAL ### + messages = { + "role": "user", + "content": "Hello, what is AI?" + }, + { + "role": "user", + "content": "How does AI work? Explain it in simple terms." + }] + ) + ``` + """ + optional_params = GenericLiteLLMParams(**kwargs) + + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + response: Optional[Thread] = None + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + response = openai_assistants_api.create_thread( + messages=messages, + metadata=metadata, + api_base=api_base, + api_key=api_key, + timeout=timeout, + max_retries=optional_params.max_retries, + organization=organization, + client=client, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'create_thread'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + + +def get_thread( + custom_llm_provider: Literal["openai"], + thread_id: str, + client: Optional[OpenAI] = None, + **kwargs, +) -> Thread: + """Get the thread object, given a thread_id""" + optional_params = GenericLiteLLMParams(**kwargs) + + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + response: Optional[Thread] = None + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + response = openai_assistants_api.get_thread( + thread_id=thread_id, + api_base=api_base, + api_key=api_key, + timeout=timeout, + max_retries=optional_params.max_retries, + organization=organization, + client=client, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'get_thread'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + + +### MESSAGES ### + + +def add_message( + custom_llm_provider: Literal["openai"], + thread_id: str, + role: Literal["user", "assistant"], + content: str, + attachments: Optional[List[Attachment]] = None, + metadata: Optional[dict] = None, + client: Optional[OpenAI] = None, + **kwargs, +) -> OpenAIMessage: + ### COMMON OBJECTS ### + message_data = MessageData( + role=role, content=content, attachments=attachments, metadata=metadata + ) + optional_params = GenericLiteLLMParams(**kwargs) + + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + response: Optional[OpenAIMessage] = None + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + response = openai_assistants_api.add_message( + thread_id=thread_id, + message_data=message_data, + api_base=api_base, + api_key=api_key, + timeout=timeout, + max_retries=optional_params.max_retries, + organization=organization, + client=client, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'create_thread'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + + return response + + +def get_messages( + custom_llm_provider: Literal["openai"], + thread_id: str, + client: Optional[OpenAI] = None, + **kwargs, +) -> SyncCursorPage[OpenAIMessage]: + optional_params = GenericLiteLLMParams(**kwargs) + + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + response: Optional[SyncCursorPage[OpenAIMessage]] = None + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + response = openai_assistants_api.get_messages( + thread_id=thread_id, + api_base=api_base, + api_key=api_key, + timeout=timeout, + max_retries=optional_params.max_retries, + organization=organization, + client=client, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'get_messages'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + + return response + + +### RUNS ### + + +def run_thread( + custom_llm_provider: Literal["openai"], + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str] = None, + instructions: Optional[str] = None, + metadata: Optional[dict] = None, + model: Optional[str] = None, + stream: Optional[bool] = None, + tools: Optional[Iterable[AssistantToolParam]] = None, + client: Optional[OpenAI] = None, + **kwargs, +) -> Run: + """Run a given thread + assistant.""" + optional_params = GenericLiteLLMParams(**kwargs) + + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + response: Optional[Run] = None + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + response = openai_assistants_api.run_thread( + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + stream=stream, + tools=tools, + api_base=api_base, + api_key=api_key, + timeout=timeout, + max_retries=optional_params.max_retries, + organization=organization, + client=client, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'run_thread'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 5a76605b3..f007507c9 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -1,4 +1,13 @@ -from typing import Optional, Union, Any, BinaryIO +from typing import ( + Optional, + Union, + Any, + BinaryIO, + Literal, + Iterable, +) +from typing_extensions import override +from pydantic import BaseModel import types, time, json, traceback import httpx from .base import BaseLLM @@ -17,6 +26,7 @@ import aiohttp, requests import litellm from .prompt_templates.factory import prompt_factory, custom_prompt from openai import OpenAI, AsyncOpenAI +from ..types.llms.openai import * class OpenAIError(Exception): @@ -1236,3 +1246,223 @@ class OpenAITextCompletion(BaseLLM): async for transformed_chunk in streamwrapper: yield transformed_chunk + + +class OpenAIAssistantsAPI(BaseLLM): + def __init__(self) -> None: + super().__init__() + + def get_openai_client( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI] = None, + ) -> OpenAI: + received_args = locals() + if client is None: + data = {} + for k, v in received_args.items(): + if k == "self" or k == "client": + pass + elif k == "api_base" and v is not None: + data["base_url"] = v + elif v is not None: + data[k] = v + openai_client = OpenAI(**data) # type: ignore + else: + openai_client = client + + return openai_client + + ### ASSISTANTS ### + + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI], + ) -> SyncCursorPage[Assistant]: + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = openai_client.beta.assistants.list() + + return response + + ### MESSAGES ### + + def add_message( + self, + thread_id: str, + message_data: MessageData, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI] = None, + ) -> OpenAIMessage: + + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( + thread_id, **message_data + ) + + response_obj: Optional[OpenAIMessage] = None + if getattr(thread_message, "status", None) is None: + thread_message.status = "completed" + response_obj = OpenAIMessage(**thread_message.dict()) + else: + response_obj = OpenAIMessage(**thread_message.dict()) + return response_obj + + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI] = None, + ) -> SyncCursorPage[OpenAIMessage]: + openai_client = self.get_openai_client( + 
+    def get_messages(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI] = None,
+    ) -> SyncCursorPage[OpenAIMessage]:
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        response = openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+        return response
+
+    ### THREADS ###
+
+    def create_thread(
+        self,
+        metadata: Optional[dict],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI],
+        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+    ) -> Thread:
+        """
+        Here's an example:
+        ```
+        from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
+
+        # create thread
+        openai_api = OpenAIAssistantsAPI()
+        message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
+        openai_api.create_thread(messages=[message])
+        ```
+        """
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        data = {}
+        if messages is not None:
+            data["messages"] = messages  # type: ignore
+        if metadata is not None:
+            data["metadata"] = metadata  # type: ignore
+
+        message_thread = openai_client.beta.threads.create(**data)  # type: ignore
+
+        return Thread(**message_thread.dict())
+
+    def get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI],
+    ) -> Thread:
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        response = openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+        return Thread(**response.dict())
+
+    def delete_thread(self):
+        pass
+
+    ### RUNS ###
+
+    def run_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        stream: Optional[bool],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI],
+    ) -> Run:
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        response = openai_client.beta.threads.runs.create_and_poll(
+            thread_id=thread_id,
+            assistant_id=assistant_id,
+            additional_instructions=additional_instructions,
+            instructions=instructions,
+            metadata=metadata,
+            model=model,
+            tools=tools,
+        )
+
+        return response
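Note on `get_openai_client`: when a pre-built `client` is passed, all other connection kwargs are ignored and the client is reused as-is, so repeated calls can share one connection pool; when it is `None`, a fresh `OpenAI` client is built from the non-`None` kwargs. A sketch of the injection path, assuming `OPENAI_API_KEY` is exported:

```python
from openai import OpenAI
from litellm.llms.openai import OpenAIAssistantsAPI

shared_client = OpenAI()  # reads OPENAI_API_KEY from the environment
api = OpenAIAssistantsAPI()

# because client is not None, the api_key/api_base/timeout args below are ignored
thread = api.create_thread(
    messages=None,
    metadata=None,
    api_key=None,
    api_base=None,
    timeout=600.0,
    max_retries=None,
    organization=None,
    client=shared_client,
)
print(thread.id)
```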
diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py
new file mode 100644
index 000000000..7f20a6df0
--- /dev/null
+++ b/litellm/tests/test_assistants.py
@@ -0,0 +1,102 @@
+# What is this?
+## Unit Tests for OpenAI Assistants API
+import sys, os, json
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest, logging, asyncio
+import litellm
+from litellm import create_thread, get_thread
+from litellm.llms.openai import (
+    OpenAIAssistantsAPI,
+    MessageData,
+    Thread,
+    OpenAIMessage as Message,
+)
+
+"""
+V0 Scope:
+
+- Add Message -> `/v1/threads/{thread_id}/messages`
+- Run Thread -> `/v1/threads/{thread_id}/run`
+"""
+
+
+def test_create_thread_litellm() -> Thread:
+    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
+    new_thread = create_thread(
+        custom_llm_provider="openai",
+        messages=[message],  # type: ignore
+    )
+
+    assert isinstance(
+        new_thread, Thread
+    ), f"type of thread={type(new_thread)}. Expected Thread-type"
+    return new_thread
+
+
+def test_get_thread_litellm():
+    new_thread = test_create_thread_litellm()
+
+    received_thread = get_thread(
+        custom_llm_provider="openai",
+        thread_id=new_thread.id,
+    )
+
+    assert isinstance(
+        received_thread, Thread
+    ), f"type of thread={type(received_thread)}. Expected Thread-type"
+    return new_thread
+
+
+def test_add_message_litellm():
+    new_thread = test_create_thread_litellm()
+
+    # add message to thread
+    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
+    added_message = litellm.add_message(
+        thread_id=new_thread.id, custom_llm_provider="openai", **message
+    )
+
+    print(f"added message: {added_message}")
+
+    assert isinstance(added_message, Message)
+
+
+def test_run_thread_litellm():
+    """
+    - Get Assistants
+    - Create thread
+    - Create run w/ Assistants + Thread
+    """
+    assistants = litellm.get_assistants(custom_llm_provider="openai")
+
+    ## get the first assistant ###
+    assistant_id = assistants.data[0].id
+
+    new_thread = test_create_thread_litellm()
+
+    thread_id = new_thread.id
+
+    # add message to thread
+    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
+    added_message = litellm.add_message(
+        thread_id=new_thread.id, custom_llm_provider="openai", **message
+    )
+
+    run = litellm.run_thread(
+        custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
+    )
+
+    if run.status == "completed":
+        messages = litellm.get_messages(
+            thread_id=new_thread.id, custom_llm_provider="openai"
+        )
+        assert isinstance(messages.data[0], Message)
+    else:
+        pytest.fail("An unexpected error occurred when running the thread")
diff --git a/litellm/types/llms/__init__.py b/litellm/types/llms/__init__.py
new file mode 100644
index 000000000..14952c9ae
--- /dev/null
+++ b/litellm/types/llms/__init__.py
@@ -0,0 +1,3 @@
+__all__ = ["openai"]
+
+from . import openai
diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
new file mode 100644
index 000000000..1c60ad6db
--- /dev/null
+++ b/litellm/types/llms/openai.py
@@ -0,0 +1,148 @@
+from typing import (
+    Optional,
+    Union,
+    Any,
+    BinaryIO,
+    Literal,
+    Iterable,
+)
+from typing_extensions import override, Required
+from pydantic import BaseModel
+
+from openai.types.beta.threads.message_content import MessageContent
+from openai.types.beta.threads.message import Message as OpenAIMessage
+from openai.types.beta.thread_create_params import (
+    Message as OpenAICreateThreadParamsMessage,
+)
+from openai.types.beta.assistant_tool_param import AssistantToolParam
+from openai.types.beta.threads.run import Run
+from openai.types.beta.assistant import Assistant
+from openai.pagination import SyncCursorPage
+
+from typing import TypedDict, List, Optional
+
+
+class NotGiven:
+    """
+    A sentinel singleton class used to distinguish omitted keyword arguments
+    from those passed in with the value None (which may have different behavior).
+
+    For example:
+
+    ```py
+    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
+
+
+    get(timeout=1)  # 1s timeout
+    get(timeout=None)  # No timeout
+    get()  # Default timeout behavior, which may not be statically known at the method definition.
+    ```
+    """
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+    @override
+    def __repr__(self) -> str:
+        return "NOT_GIVEN"
+
+
+NOT_GIVEN = NotGiven()
+
+
+class ToolResourcesCodeInterpreter(TypedDict, total=False):
+    file_ids: List[str]
+    """
+    A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made
+    available to the `code_interpreter` tool. There can be a maximum of 20 files
+    associated with the tool.
+    """
+
+
+class ToolResourcesFileSearchVectorStore(TypedDict, total=False):
+    file_ids: List[str]
+    """
+    A list of [file](https://platform.openai.com/docs/api-reference/files) IDs to
+    add to the vector store. There can be a maximum of 10000 files in a vector
+    store.
+    """
+
+    metadata: object
+    """Set of 16 key-value pairs that can be attached to a vector store.
+
+    This can be useful for storing additional information about the vector store in
+    a structured format. Keys can be a maximum of 64 characters long and values can
+    be a maximum of 512 characters long.
+    """
+
+
+class ToolResourcesFileSearch(TypedDict, total=False):
+    vector_store_ids: List[str]
+    """
+    The
+    [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object)
+    attached to this thread. There can be a maximum of 1 vector store attached to
+    the thread.
+    """
+
+    vector_stores: Iterable[ToolResourcesFileSearchVectorStore]
+    """
+    A helper to create a
+    [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object)
+    with file_ids and attach it to this thread. There can be a maximum of 1 vector
+    store attached to the thread.
+    """
+
+
+class OpenAICreateThreadParamsToolResources(TypedDict, total=False):
+    code_interpreter: ToolResourcesCodeInterpreter
+
+    file_search: ToolResourcesFileSearch
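Because these are `total=False` TypedDicts, every key is optional and plain dict literals type-check against them. A sketch of building a `tool_resources` payload, with a hypothetical vector-store id:

```python
from litellm.types.llms.openai import (
    OpenAICreateThreadParamsToolResources,
    ToolResourcesFileSearch,
)

file_search: ToolResourcesFileSearch = {
    "vector_store_ids": ["vs_abc123"],  # hypothetical vector store id
}
tool_resources: OpenAICreateThreadParamsToolResources = {
    "file_search": file_search,
}
```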
+ """ + + +class OpenAICreateThreadParamsToolResources(TypedDict, total=False): + code_interpreter: ToolResourcesCodeInterpreter + + file_search: ToolResourcesFileSearch + + +class FileSearchToolParam(TypedDict, total=False): + type: Required[Literal["file_search"]] + """The type of tool being defined: `file_search`""" + + +class CodeInterpreterToolParam(TypedDict, total=False): + type: Required[Literal["code_interpreter"]] + """The type of tool being defined: `code_interpreter`""" + + +AttachmentTool = Union[CodeInterpreterToolParam, FileSearchToolParam] + + +class Attachment(TypedDict, total=False): + file_id: str + """The ID of the file to attach to the message.""" + + tools: Iterable[AttachmentTool] + """The tools to add this file to.""" + + +class MessageData(TypedDict): + role: Literal["user", "assistant"] + content: str + attachments: Optional[List[Attachment]] + metadata: Optional[dict] + + +class Thread(BaseModel): + id: str + """The identifier, which can be referenced in API endpoints.""" + + created_at: int + """The Unix timestamp (in seconds) for when the thread was created.""" + + metadata: Optional[object] = None + """Set of 16 key-value pairs that can be attached to an object. + + This can be useful for storing additional information about the object in a + structured format. Keys can be a maximum of 64 characters long and values can be + a maxium of 512 characters long. + """ + + object: Literal["thread"] + """The object type, which is always `thread`.""" diff --git a/litellm/types/router.py b/litellm/types/router.py index 1f5fb7103..dbfc549cd 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -97,8 +97,11 @@ class ModelInfo(BaseModel): setattr(self, key, value) -class LiteLLM_Params(BaseModel): - model: str +class GenericLiteLLMParams(BaseModel): + """ + LiteLLM Params without 'model' arg (used across completion / assistants api) + """ + custom_llm_provider: Optional[str] = None tpm: Optional[int] = None rpm: Optional[int] = None @@ -121,6 +124,66 @@ class LiteLLM_Params(BaseModel): aws_secret_access_key: Optional[str] = None aws_region_name: Optional[str] = None + def __init__( + self, + custom_llm_provider: Optional[str] = None, + max_retries: Optional[Union[int, str]] = None, + tpm: Optional[int] = None, + rpm: Optional[int] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + timeout: Optional[Union[float, str]] = None, # if str, pass in as os.environ/ + stream_timeout: Optional[Union[float, str]] = ( + None # timeout when making stream=True calls, if str, pass in as os.environ/ + ), + organization: Optional[str] = None, # for openai orgs + ## VERTEX AI ## + vertex_project: Optional[str] = None, + vertex_location: Optional[str] = None, + ## AWS BEDROCK / SAGEMAKER ## + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_region_name: Optional[str] = None, + **params + ): + args = locals() + args.pop("max_retries", None) + args.pop("self", None) + args.pop("params", None) + args.pop("__class__", None) + if max_retries is not None and isinstance(max_retries, str): + max_retries = int(max_retries) # cast to int + super().__init__(max_retries=max_retries, **args, **params) + + class Config: + extra = "allow" + arbitrary_types_allowed = True + + def __contains__(self, key): + # Define custom behavior for the 'in' operator + return hasattr(self, key) + + def get(self, key, default=None): + # Custom .get() method to access attributes with a default value if 
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
+class LiteLLM_Params(GenericLiteLLMParams):
+    """
+    LiteLLM Params with 'model' requirement - used for completions
+    """
+
+    model: str
+
     def __init__(
         self,
         model: str,
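The dunder additions make params objects usable like dicts, which is how the assistants entrypoints consume `GenericLiteLLMParams(**kwargs)`. A sketch of the resulting behavior (the api key is a placeholder):

```python
from litellm.types.router import GenericLiteLLMParams

params = GenericLiteLLMParams(api_key="sk-...", max_retries="3")  # placeholder key

print(params.max_retries)              # 3 - the "3" string was cast to int in __init__
print("api_key" in params)             # True, via the custom __contains__
print(params.get("unset_key", "n/a"))  # "n/a", via the custom .get() default
params["rpm"] = 60                     # dict-style assignment via __setitem__
print(params["rpm"])                   # 60, dict-style access via __getitem__
```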