mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
* fix(utils.py): e2e azure tts cost tracking working moves tts response obj to include hidden params (allows for litellm call id, etc. to be sent in response headers) ; fixes spend_Tracking_utils logging payload to account for non-base model use-case Fixes https://github.com/BerriAI/litellm/issues/7223 * fix: fix linting errors * build(model_prices_and_context_window.json): add bedrock llama 3.3 Closes https://github.com/BerriAI/litellm/issues/7329 * fix(openai.py): fix return type for sync openai httpx response * test: update test * fix(spend_tracking_utils.py): fix if check * fix(spend_tracking_utils.py): fix if check * test: improve debugging for test * fix: fix import
2768 lines
96 KiB
Python
2768 lines
96 KiB
Python
import hashlib
|
|
import os
|
|
import types
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Coroutine,
|
|
Iterable,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
import httpx
|
|
import openai
|
|
from openai import AsyncOpenAI, OpenAI
|
|
from openai.types.beta.assistant_deleted import AssistantDeleted
|
|
from openai.types.file_deleted import FileDeleted
|
|
from pydantic import BaseModel
|
|
from typing_extensions import overload
|
|
|
|
import litellm
|
|
from litellm import LlmProviders
|
|
from litellm._logging import verbose_logger
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
|
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
|
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
|
|
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
|
|
from litellm.types.utils import EmbeddingResponse, ImageResponse, ModelResponse
|
|
from litellm.utils import (
|
|
CustomStreamWrapper,
|
|
ProviderConfigManager,
|
|
convert_to_model_response_object,
|
|
)
|
|
|
|
from ...types.llms.openai import *
|
|
from ..base import BaseLLM
|
|
from .chat.gpt_transformation import OpenAIGPTConfig
|
|
from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
|
|
|
|
|
|
class MistralEmbeddingConfig:
    """
    Maps OpenAI-style embedding parameters onto Mistral's embedding endpoint.

    Reference: https://docs.mistral.ai/api/#operation/createEmbedding
    """

    def __init__(
        self,
    ) -> None:
        # No configurable constructor arguments exist, so there is nothing to
        # store on the class.
        pass

    @classmethod
    def get_config(cls):
        """Return the non-dunder, non-callable, non-None class attributes."""
        skipped_kinds = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        settings = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith("__"):
                continue
            if isinstance(attr_value, skipped_kinds) or attr_value is None:
                continue
            settings[attr_name] = attr_value
        return settings

    def get_supported_openai_params(self):
        """Return the OpenAI embedding params this provider understands."""
        return ["encoding_format"]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        """
        Copy supported params from ``non_default_params`` into ``optional_params``.

        ``optional_params`` is updated in place and also returned.
        """
        if "encoding_format" in non_default_params:
            optional_params["encoding_format"] = non_default_params["encoding_format"]
        return optional_params
|
|
|
|
|
|
class OpenAIConfig(BaseConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `OpenAIConfig` provides configuration for OpenAI's Chat API interface. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the chat completion. OpenAI has deprecated it in favor of max_completion_tokens, and it is not compatible with o1 series models.

    - `max_completion_tokens` (integer or null): An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.

    - `n` (integer or null): This optional parameter sets how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on whether they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.

    - `response_format` (object or null): Specifies the format the model must output.

    Note: non-None arguments passed to ``__init__`` are stored on the *class*
    (``setattr(self.__class__, ...)``), so they are shared by every instance.
    """

    # Real-valued sampling/penalty params are floats (ints are still accepted).
    frequency_penalty: Optional[float] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_completion_tokens: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    response_format: Optional[dict] = None

    def __init__(
        self,
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        response_format: Optional[dict] = None,
    ) -> None:
        # Must be the first statement: locals() here is exactly the arguments.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                # Deliberately class-level, so the setting applies globally.
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the configured class-level params (delegates to BaseConfig)."""
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        """
        This function returns the list
        of supported openai parameters for a given OpenAI Model

        - If O1 model, returns O1 supported params
        - If gpt-audio model, returns gpt-audio supported params
        - Else, returns gpt supported params

        Args:
            model (str): OpenAI model

        Returns:
            list: List of supported openai parameters
        """
        if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
            return litellm.openAIO1Config.get_supported_openai_params(model=model)
        elif litellm.openAIGPTAudioConfig.is_model_gpt_audio_model(model=model):
            return litellm.openAIGPTAudioConfig.get_supported_openai_params(model=model)
        else:
            return litellm.openAIGPTConfig.get_supported_openai_params(model=model)

    def _map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ) -> dict:
        """
        Copy every param supported by ``model`` into ``optional_params``.

        ``optional_params`` is updated in place and returned. Unsupported
        params are silently ignored.
        """
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        """Identity transform: OpenAI accepts the messages unchanged."""
        return messages

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params for ``model`` by delegating to the matching config:
        the o1 reasoning config, the gpt-audio config, or the default GPT config.
        """
        if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
            return litellm.openAIO1Config.map_openai_params(
                non_default_params=non_default_params,
                optional_params=optional_params,
                model=model,
                drop_params=drop_params,
            )
        elif litellm.openAIGPTAudioConfig.is_model_gpt_audio_model(model=model):
            return litellm.openAIGPTAudioConfig.map_openai_params(
                non_default_params=non_default_params,
                optional_params=optional_params,
                model=model,
                drop_params=drop_params,
            )

        return litellm.openAIGPTConfig.map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Build the provider-specific exception (``OpenAIError``)."""
        return OpenAIError(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the request body: ``model`` and ``messages`` plus every optional
        param spread at top level. ``litellm_params`` and ``headers`` are unused.
        """
        return {"model": model, "messages": messages, **optional_params}

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Not supported: the OpenAI SDK handler parses responses itself."""
        raise NotImplementedError(
            "OpenAI handler does this transformation as it uses the OpenAI SDK."
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        """Not supported: the OpenAI SDK handler validates the environment itself."""
        raise NotImplementedError(
            "OpenAI handler does this validation as it uses the OpenAI SDK."
        )
|
|
|
|
|
|
class OpenAIChatCompletion(BaseLLM):
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def _get_openai_client(
|
|
self,
|
|
is_async: bool,
|
|
api_key: Optional[str] = None,
|
|
api_base: Optional[str] = None,
|
|
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
|
max_retries: Optional[int] = 2,
|
|
organization: Optional[str] = None,
|
|
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
|
|
):
|
|
args = locals()
|
|
if client is None:
|
|
if not isinstance(max_retries, int):
|
|
raise OpenAIError(
|
|
status_code=422,
|
|
message="max retries must be an int. Passed in value: {}".format(
|
|
max_retries
|
|
),
|
|
)
|
|
# Creating a new OpenAI Client
|
|
# check in memory cache before creating a new one
|
|
# Convert the API key to bytes
|
|
hashed_api_key = None
|
|
if api_key is not None:
|
|
hash_object = hashlib.sha256(api_key.encode())
|
|
# Hexadecimal representation of the hash
|
|
hashed_api_key = hash_object.hexdigest()
|
|
|
|
_cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"
|
|
|
|
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
|
|
if _cached_client:
|
|
return _cached_client
|
|
if is_async:
|
|
_new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
|
|
api_key=api_key,
|
|
base_url=api_base,
|
|
http_client=litellm.aclient_session,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
organization=organization,
|
|
)
|
|
else:
|
|
_new_client = OpenAI(
|
|
api_key=api_key,
|
|
base_url=api_base,
|
|
http_client=litellm.client_session,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
organization=organization,
|
|
)
|
|
|
|
## SAVE CACHE KEY
|
|
litellm.in_memory_llm_clients_cache.set_cache(
|
|
key=_cache_key,
|
|
value=_new_client,
|
|
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
|
|
)
|
|
return _new_client
|
|
|
|
else:
|
|
return client
|
|
|
|
async def make_openai_chat_completion_request(
|
|
self,
|
|
openai_aclient: AsyncOpenAI,
|
|
data: dict,
|
|
timeout: Union[float, httpx.Timeout],
|
|
) -> Tuple[dict, BaseModel]:
|
|
"""
|
|
Helper to:
|
|
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
|
|
- call chat.completions.create by default
|
|
"""
|
|
try:
|
|
raw_response = (
|
|
await openai_aclient.chat.completions.with_raw_response.create(
|
|
**data, timeout=timeout
|
|
)
|
|
)
|
|
|
|
if hasattr(raw_response, "headers"):
|
|
headers = dict(raw_response.headers)
|
|
else:
|
|
headers = {}
|
|
response = raw_response.parse()
|
|
return headers, response
|
|
except Exception as e:
|
|
raise e
|
|
|
|
def make_sync_openai_chat_completion_request(
|
|
self,
|
|
openai_client: OpenAI,
|
|
data: dict,
|
|
timeout: Union[float, httpx.Timeout],
|
|
) -> Tuple[dict, BaseModel]:
|
|
"""
|
|
Helper to:
|
|
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
|
|
- call chat.completions.create by default
|
|
"""
|
|
raw_response = None
|
|
try:
|
|
raw_response = openai_client.chat.completions.with_raw_response.create(
|
|
**data, timeout=timeout
|
|
)
|
|
|
|
if hasattr(raw_response, "headers"):
|
|
headers = dict(raw_response.headers)
|
|
else:
|
|
headers = {}
|
|
response = raw_response.parse()
|
|
return headers, response
|
|
except Exception as e:
|
|
if raw_response is not None:
|
|
raise Exception(
|
|
"error - {}, Received response - {}, Type of response - {}".format(
|
|
e, raw_response, type(raw_response)
|
|
)
|
|
)
|
|
else:
|
|
raise e
|
|
|
|
def mock_streaming(
|
|
self,
|
|
response: ModelResponse,
|
|
logging_obj: LiteLLMLoggingObj,
|
|
model: str,
|
|
stream_options: Optional[dict] = None,
|
|
) -> CustomStreamWrapper:
|
|
completion_stream = MockResponseIterator(model_response=response)
|
|
streaming_response = CustomStreamWrapper(
|
|
completion_stream=completion_stream,
|
|
model=model,
|
|
custom_llm_provider="openai",
|
|
logging_obj=logging_obj,
|
|
stream_options=stream_options,
|
|
)
|
|
|
|
return streaming_response
|
|
|
|
def completion( # type: ignore # noqa: PLR0915
|
|
self,
|
|
model_response: ModelResponse,
|
|
timeout: Union[float, httpx.Timeout],
|
|
optional_params: dict,
|
|
litellm_params: dict,
|
|
logging_obj: Any,
|
|
model: Optional[str] = None,
|
|
messages: Optional[list] = None,
|
|
print_verbose: Optional[Callable] = None,
|
|
api_key: Optional[str] = None,
|
|
api_base: Optional[str] = None,
|
|
acompletion: bool = False,
|
|
logger_fn=None,
|
|
headers: Optional[dict] = None,
|
|
custom_prompt_dict: dict = {},
|
|
client=None,
|
|
organization: Optional[str] = None,
|
|
custom_llm_provider: Optional[str] = None,
|
|
drop_params: Optional[bool] = None,
|
|
):
|
|
super().completion()
|
|
try:
|
|
fake_stream: bool = False
|
|
inference_params = optional_params.copy()
|
|
stream_options: Optional[dict] = inference_params.pop(
|
|
"stream_options", None
|
|
)
|
|
stream: Optional[bool] = inference_params.pop("stream", False)
|
|
if custom_llm_provider is not None and model is not None:
|
|
provider_config = ProviderConfigManager.get_provider_chat_config(
|
|
model=model, provider=LlmProviders(custom_llm_provider)
|
|
)
|
|
fake_stream = provider_config.should_fake_stream(
|
|
model=model, custom_llm_provider=custom_llm_provider, stream=stream
|
|
)
|
|
if headers:
|
|
inference_params["extra_headers"] = headers
|
|
if model is None or messages is None:
|
|
raise OpenAIError(status_code=422, message="Missing model or messages")
|
|
|
|
if not isinstance(timeout, float) and not isinstance(
|
|
timeout, httpx.Timeout
|
|
):
|
|
raise OpenAIError(
|
|
status_code=422,
|
|
message="Timeout needs to be a float or httpx.Timeout",
|
|
)
|
|
|
|
if custom_llm_provider is not None and custom_llm_provider != "openai":
|
|
model_response.model = f"{custom_llm_provider}/{model}"
|
|
|
|
if messages is not None and custom_llm_provider is not None:
|
|
provider_config = ProviderConfigManager.get_provider_chat_config(
|
|
model=model, provider=LlmProviders(custom_llm_provider)
|
|
)
|
|
if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
|
|
provider_config, OpenAIConfig
|
|
):
|
|
messages = provider_config._transform_messages(
|
|
messages=messages, model=model
|
|
)
|
|
|
|
for _ in range(
|
|
2
|
|
): # if call fails due to alternating messages, retry with reformatted message
|
|
data = OpenAIConfig().transform_request(
|
|
model=model,
|
|
messages=messages,
|
|
optional_params=inference_params,
|
|
litellm_params=litellm_params,
|
|
headers=headers or {},
|
|
)
|
|
try:
|
|
max_retries = data.pop("max_retries", 2)
|
|
if acompletion is True:
|
|
if stream is True and fake_stream is False:
|
|
return self.async_streaming(
|
|
logging_obj=logging_obj,
|
|
headers=headers,
|
|
data=data,
|
|
model=model,
|
|
api_base=api_base,
|
|
api_key=api_key,
|
|
timeout=timeout,
|
|
client=client,
|
|
max_retries=max_retries,
|
|
organization=organization,
|
|
drop_params=drop_params,
|
|
stream_options=stream_options,
|
|
)
|
|
else:
|
|
return self.acompletion(
|
|
data=data,
|
|
headers=headers,
|
|
model=model,
|
|
logging_obj=logging_obj,
|
|
model_response=model_response,
|
|
api_base=api_base,
|
|
api_key=api_key,
|
|
timeout=timeout,
|
|
client=client,
|
|
max_retries=max_retries,
|
|
organization=organization,
|
|
drop_params=drop_params,
|
|
fake_stream=fake_stream,
|
|
)
|
|
elif stream is True and fake_stream is False:
|
|
return self.streaming(
|
|
logging_obj=logging_obj,
|
|
headers=headers,
|
|
data=data,
|
|
model=model,
|
|
api_base=api_base,
|
|
api_key=api_key,
|
|
timeout=timeout,
|
|
client=client,
|
|
max_retries=max_retries,
|
|
organization=organization,
|
|
stream_options=stream_options,
|
|
)
|
|
else:
|
|
if not isinstance(max_retries, int):
|
|
raise OpenAIError(
|
|
status_code=422, message="max retries must be an int"
|
|
)
|
|
|
|
openai_client: OpenAI = self._get_openai_client( # type: ignore
|
|
is_async=False,
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
organization=organization,
|
|
client=client,
|
|
)
|
|
|
|
## LOGGING
|
|
logging_obj.pre_call(
|
|
input=messages,
|
|
api_key=openai_client.api_key,
|
|
additional_args={
|
|
"headers": headers,
|
|
"api_base": openai_client._base_url._uri_reference,
|
|
"acompletion": acompletion,
|
|
"complete_input_dict": data,
|
|
},
|
|
)
|
|
|
|
headers, response = (
|
|
self.make_sync_openai_chat_completion_request(
|
|
openai_client=openai_client,
|
|
data=data,
|
|
timeout=timeout,
|
|
)
|
|
)
|
|
|
|
logging_obj.model_call_details["response_headers"] = headers
|
|
stringified_response = response.model_dump()
|
|
logging_obj.post_call(
|
|
input=messages,
|
|
api_key=api_key,
|
|
original_response=stringified_response,
|
|
additional_args={"complete_input_dict": data},
|
|
)
|
|
|
|
final_response_obj = convert_to_model_response_object(
|
|
response_object=stringified_response,
|
|
model_response_object=model_response,
|
|
_response_headers=headers,
|
|
)
|
|
if fake_stream is True:
|
|
return self.mock_streaming(
|
|
response=cast(ModelResponse, final_response_obj),
|
|
logging_obj=logging_obj,
|
|
model=model,
|
|
stream_options=stream_options,
|
|
)
|
|
|
|
return final_response_obj
|
|
except openai.UnprocessableEntityError as e:
|
|
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
|
|
if litellm.drop_params is True or drop_params is True:
|
|
inference_params = drop_params_from_unprocessable_entity_error(
|
|
e, inference_params
|
|
)
|
|
else:
|
|
raise e
|
|
# e.message
|
|
except Exception as e:
|
|
if print_verbose is not None:
|
|
print_verbose(f"openai.py: Received openai error - {str(e)}")
|
|
if (
|
|
"Conversation roles must alternate user/assistant" in str(e)
|
|
or "user and assistant roles should be alternating" in str(e)
|
|
) and messages is not None:
|
|
if print_verbose is not None:
|
|
print_verbose("openai.py: REFORMATS THE MESSAGE!")
|
|
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
|
|
new_messages = []
|
|
for i in range(len(messages) - 1): # type: ignore
|
|
new_messages.append(messages[i])
|
|
if messages[i]["role"] == messages[i + 1]["role"]:
|
|
if messages[i]["role"] == "user":
|
|
new_messages.append(
|
|
{"role": "assistant", "content": ""}
|
|
)
|
|
else:
|
|
new_messages.append({"role": "user", "content": ""})
|
|
new_messages.append(messages[-1])
|
|
messages = new_messages
|
|
elif (
|
|
"Last message must have role `user`" in str(e)
|
|
) and messages is not None:
|
|
new_messages = messages
|
|
new_messages.append({"role": "user", "content": ""})
|
|
messages = new_messages
|
|
elif (
|
|
"unknown field: parameter index is not a valid field" in str(e)
|
|
) and "tools" in data:
|
|
litellm.remove_index_from_tool_calls(
|
|
tool_calls=data["tools"], messages=messages
|
|
)
|
|
else:
|
|
raise e
|
|
except OpenAIError as e:
|
|
raise e
|
|
except Exception as e:
|
|
status_code = getattr(e, "status_code", 500)
|
|
error_headers = getattr(e, "headers", None)
|
|
error_text = getattr(e, "text", str(e))
|
|
error_response = getattr(e, "response", None)
|
|
if error_headers is None and error_response:
|
|
error_headers = getattr(error_response, "headers", None)
|
|
raise OpenAIError(
|
|
status_code=status_code, message=error_text, headers=error_headers
|
|
)
|
|
|
|
async def acompletion(
|
|
self,
|
|
data: dict,
|
|
model: str,
|
|
model_response: ModelResponse,
|
|
logging_obj: LiteLLMLoggingObj,
|
|
timeout: Union[float, httpx.Timeout],
|
|
api_key: Optional[str] = None,
|
|
api_base: Optional[str] = None,
|
|
organization: Optional[str] = None,
|
|
client=None,
|
|
max_retries=None,
|
|
headers=None,
|
|
drop_params: Optional[bool] = None,
|
|
stream_options: Optional[dict] = None,
|
|
fake_stream: bool = False,
|
|
):
|
|
response = None
|
|
for _ in range(
|
|
2
|
|
): # if call fails due to alternating messages, retry with reformatted message
|
|
try:
|
|
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
|
|
is_async=True,
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
organization=organization,
|
|
client=client,
|
|
)
|
|
|
|
## LOGGING
|
|
logging_obj.pre_call(
|
|
input=data["messages"],
|
|
api_key=openai_aclient.api_key,
|
|
additional_args={
|
|
"headers": {
|
|
"Authorization": f"Bearer {openai_aclient.api_key}"
|
|
},
|
|
"api_base": openai_aclient._base_url._uri_reference,
|
|
"acompletion": True,
|
|
"complete_input_dict": data,
|
|
},
|
|
)
|
|
|
|
headers, response = await self.make_openai_chat_completion_request(
|
|
openai_aclient=openai_aclient, data=data, timeout=timeout
|
|
)
|
|
stringified_response = response.model_dump()
|
|
|
|
logging_obj.post_call(
|
|
input=data["messages"],
|
|
api_key=api_key,
|
|
original_response=stringified_response,
|
|
additional_args={"complete_input_dict": data},
|
|
)
|
|
logging_obj.model_call_details["response_headers"] = headers
|
|
final_response_obj = convert_to_model_response_object(
|
|
response_object=stringified_response,
|
|
model_response_object=model_response,
|
|
hidden_params={"headers": headers},
|
|
_response_headers=headers,
|
|
)
|
|
|
|
if fake_stream is True:
|
|
return self.mock_streaming(
|
|
response=cast(ModelResponse, final_response_obj),
|
|
logging_obj=logging_obj,
|
|
model=model,
|
|
stream_options=stream_options,
|
|
)
|
|
|
|
return final_response_obj
|
|
except openai.UnprocessableEntityError as e:
|
|
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
|
|
if litellm.drop_params is True or drop_params is True:
|
|
data = drop_params_from_unprocessable_entity_error(e, data)
|
|
else:
|
|
raise e
|
|
# e.message
|
|
except Exception as e:
|
|
exception_response = getattr(e, "response", None)
|
|
status_code = getattr(e, "status_code", 500)
|
|
error_headers = getattr(e, "headers", None)
|
|
if error_headers is None and exception_response:
|
|
error_headers = getattr(exception_response, "headers", None)
|
|
|
|
raise OpenAIError(
|
|
status_code=status_code, message=str(e), headers=error_headers
|
|
)
|
|
|
|
def streaming(
|
|
self,
|
|
logging_obj,
|
|
timeout: Union[float, httpx.Timeout],
|
|
data: dict,
|
|
model: str,
|
|
api_key: Optional[str] = None,
|
|
api_base: Optional[str] = None,
|
|
organization: Optional[str] = None,
|
|
client=None,
|
|
max_retries=None,
|
|
headers=None,
|
|
stream_options: Optional[dict] = None,
|
|
):
|
|
data["stream"] = True
|
|
if stream_options is not None:
|
|
data["stream_options"] = stream_options
|
|
openai_client: OpenAI = self._get_openai_client( # type: ignore
|
|
is_async=False,
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
organization=organization,
|
|
client=client,
|
|
)
|
|
## LOGGING
|
|
logging_obj.pre_call(
|
|
input=data["messages"],
|
|
api_key=api_key,
|
|
additional_args={
|
|
"headers": {"Authorization": f"Bearer {openai_client.api_key}"},
|
|
"api_base": openai_client._base_url._uri_reference,
|
|
"acompletion": False,
|
|
"complete_input_dict": data,
|
|
},
|
|
)
|
|
headers, response = self.make_sync_openai_chat_completion_request(
|
|
openai_client=openai_client,
|
|
data=data,
|
|
timeout=timeout,
|
|
)
|
|
|
|
logging_obj.model_call_details["response_headers"] = headers
|
|
streamwrapper = CustomStreamWrapper(
|
|
completion_stream=response,
|
|
model=model,
|
|
custom_llm_provider="openai",
|
|
logging_obj=logging_obj,
|
|
stream_options=data.get("stream_options", None),
|
|
_response_headers=headers,
|
|
)
|
|
return streamwrapper
|
|
|
|
async def async_streaming(
|
|
self,
|
|
timeout: Union[float, httpx.Timeout],
|
|
data: dict,
|
|
model: str,
|
|
logging_obj: LiteLLMLoggingObj,
|
|
api_key: Optional[str] = None,
|
|
api_base: Optional[str] = None,
|
|
organization: Optional[str] = None,
|
|
client=None,
|
|
max_retries=None,
|
|
headers=None,
|
|
drop_params: Optional[bool] = None,
|
|
stream_options: Optional[dict] = None,
|
|
):
|
|
response = None
|
|
data["stream"] = True
|
|
if stream_options is not None:
|
|
data["stream_options"] = stream_options
|
|
for _ in range(2):
|
|
try:
|
|
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
|
|
is_async=True,
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
organization=organization,
|
|
client=client,
|
|
)
|
|
## LOGGING
|
|
logging_obj.pre_call(
|
|
input=data["messages"],
|
|
api_key=api_key,
|
|
additional_args={
|
|
"headers": headers,
|
|
"api_base": api_base,
|
|
"acompletion": True,
|
|
"complete_input_dict": data,
|
|
},
|
|
)
|
|
|
|
headers, response = await self.make_openai_chat_completion_request(
|
|
openai_aclient=openai_aclient, data=data, timeout=timeout
|
|
)
|
|
logging_obj.model_call_details["response_headers"] = headers
|
|
streamwrapper = CustomStreamWrapper(
|
|
completion_stream=response,
|
|
model=model,
|
|
custom_llm_provider="openai",
|
|
logging_obj=logging_obj,
|
|
stream_options=data.get("stream_options", None),
|
|
_response_headers=headers,
|
|
)
|
|
return streamwrapper
|
|
except openai.UnprocessableEntityError as e:
|
|
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
|
|
if litellm.drop_params is True or drop_params is True:
|
|
data = drop_params_from_unprocessable_entity_error(e, data)
|
|
else:
|
|
raise e
|
|
except (
|
|
Exception
|
|
) as e: # need to exception handle here. async exceptions don't get caught in sync functions.
|
|
|
|
if isinstance(e, OpenAIError):
|
|
raise e
|
|
|
|
error_headers = getattr(e, "headers", None)
|
|
status_code = getattr(e, "status_code", 500)
|
|
error_response = getattr(e, "response", None)
|
|
if error_headers is None and error_response:
|
|
error_headers = getattr(error_response, "headers", None)
|
|
if response is not None and hasattr(response, "text"):
|
|
raise OpenAIError(
|
|
status_code=status_code,
|
|
message=f"{str(e)}\n\nOriginal Response: {response.text}", # type: ignore
|
|
headers=error_headers,
|
|
)
|
|
else:
|
|
if type(e).__name__ == "ReadTimeout":
|
|
raise OpenAIError(
|
|
status_code=408,
|
|
message=f"{type(e).__name__}",
|
|
headers=error_headers,
|
|
)
|
|
elif hasattr(e, "status_code"):
|
|
raise OpenAIError(
|
|
status_code=getattr(e, "status_code", 500),
|
|
message=str(e),
|
|
headers=error_headers,
|
|
)
|
|
else:
|
|
raise OpenAIError(
|
|
status_code=500, message=f"{str(e)}", headers=error_headers
|
|
)
|
|
|
|
# Embedding
|
|
async def make_openai_embedding_request(
|
|
self,
|
|
openai_aclient: AsyncOpenAI,
|
|
data: dict,
|
|
timeout: Union[float, httpx.Timeout],
|
|
):
|
|
"""
|
|
Helper to:
|
|
- call embeddings.create.with_raw_response when litellm.return_response_headers is True
|
|
- call embeddings.create by default
|
|
"""
|
|
try:
|
|
raw_response = await openai_aclient.embeddings.with_raw_response.create(
|
|
**data, timeout=timeout
|
|
) # type: ignore
|
|
headers = dict(raw_response.headers)
|
|
response = raw_response.parse()
|
|
return headers, response
|
|
except Exception as e:
|
|
raise e
|
|
|
|
def make_sync_openai_embedding_request(
|
|
self,
|
|
openai_client: OpenAI,
|
|
data: dict,
|
|
timeout: Union[float, httpx.Timeout],
|
|
):
|
|
"""
|
|
Helper to:
|
|
- call embeddings.create.with_raw_response when litellm.return_response_headers is True
|
|
- call embeddings.create by default
|
|
"""
|
|
try:
|
|
raw_response = openai_client.embeddings.with_raw_response.create(
|
|
**data, timeout=timeout
|
|
) # type: ignore
|
|
|
|
headers = dict(raw_response.headers)
|
|
response = raw_response.parse()
|
|
return headers, response
|
|
except Exception as e:
|
|
raise e
|
|
|
|
async def aembedding(
|
|
self,
|
|
input: list,
|
|
data: dict,
|
|
model_response: EmbeddingResponse,
|
|
timeout: float,
|
|
logging_obj: LiteLLMLoggingObj,
|
|
api_key: Optional[str] = None,
|
|
api_base: Optional[str] = None,
|
|
client: Optional[AsyncOpenAI] = None,
|
|
max_retries=None,
|
|
):
|
|
try:
|
|
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
|
|
is_async=True,
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
)
|
|
headers, response = await self.make_openai_embedding_request(
|
|
openai_aclient=openai_aclient, data=data, timeout=timeout
|
|
)
|
|
logging_obj.model_call_details["response_headers"] = headers
|
|
stringified_response = response.model_dump()
|
|
## LOGGING
|
|
logging_obj.post_call(
|
|
input=input,
|
|
api_key=api_key,
|
|
additional_args={"complete_input_dict": data},
|
|
original_response=stringified_response,
|
|
)
|
|
returned_response: EmbeddingResponse = convert_to_model_response_object(
|
|
response_object=stringified_response,
|
|
model_response_object=model_response,
|
|
response_type="embedding",
|
|
_response_headers=headers,
|
|
) # type: ignore
|
|
return returned_response
|
|
except OpenAIError as e:
|
|
## LOGGING
|
|
logging_obj.post_call(
|
|
input=input,
|
|
api_key=api_key,
|
|
additional_args={"complete_input_dict": data},
|
|
original_response=str(e),
|
|
)
|
|
raise e
|
|
except Exception as e:
|
|
## LOGGING
|
|
logging_obj.post_call(
|
|
input=input,
|
|
api_key=api_key,
|
|
additional_args={"complete_input_dict": data},
|
|
original_response=str(e),
|
|
)
|
|
status_code = getattr(e, "status_code", 500)
|
|
error_headers = getattr(e, "headers", None)
|
|
error_text = getattr(e, "text", str(e))
|
|
error_response = getattr(e, "response", None)
|
|
if error_headers is None and error_response:
|
|
error_headers = getattr(error_response, "headers", None)
|
|
raise OpenAIError(
|
|
status_code=status_code, message=error_text, headers=error_headers
|
|
)
|
|
|
|
def embedding(  # type: ignore
    self,
    model: str,
    input: list,
    timeout: float,
    logging_obj,
    model_response: EmbeddingResponse,
    optional_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    client=None,
    aembedding=None,
    max_retries: Optional[int] = None,
) -> EmbeddingResponse:
    """
    Create embeddings for ``input`` with the OpenAI SDK.

    When ``aembedding`` is True this returns the coroutine produced by
    ``self.aembedding`` (the caller must await it); otherwise the request is
    made synchronously.

    Args:
        model: Model name placed in the request body.
        input: Items to embed, placed in the request body.
        timeout: Forwarded to the client factory and to the request.
        logging_obj: LiteLLM logging object; ``pre_call``/``post_call`` are
            invoked on it and response headers are stored on it.
        model_response: Object filled in by ``convert_to_model_response_object``.
        optional_params: Extra request-body fields merged after model/input.
        api_key: Forwarded to ``self._get_openai_client``.
        api_base: Forwarded to ``self._get_openai_client``.
        client: Optional pre-built client forwarded to ``self._get_openai_client``.
        aembedding: When True, dispatch to the async implementation.
        max_retries: SDK retry count. ``None`` falls back to
            ``litellm.DEFAULT_MAX_RETRIES``; an explicit ``0`` is honoured.

    Returns:
        The populated ``EmbeddingResponse``.

    Raises:
        OpenAIError: ``status_code=422`` when ``max_retries`` is not an int;
            any other failure is re-raised as ``OpenAIError`` carrying the
            original ``status_code`` (default 500) and headers when available.
    """
    super().embedding()
    try:
        data = {"model": model, "input": input, **optional_params}
        # Only substitute the default when nothing was supplied. The previous
        # `max_retries or DEFAULT` silently replaced an explicit 0 ("no retries").
        if max_retries is None:
            max_retries = litellm.DEFAULT_MAX_RETRIES
        if not isinstance(max_retries, int):
            raise OpenAIError(status_code=422, message="max retries must be an int")

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data, "api_base": api_base},
        )

        if aembedding is True:
            return self.aembedding(  # type: ignore
                data=data,
                input=input,
                logging_obj=logging_obj,
                model_response=model_response,
                api_base=api_base,
                api_key=api_key,
                timeout=timeout,
                client=client,
                max_retries=max_retries,
            )

        openai_client: OpenAI = self._get_openai_client(  # type: ignore
            is_async=False,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )

        ## embedding CALL
        headers, sync_embedding_response = self.make_sync_openai_embedding_request(
            openai_client=openai_client, data=data, timeout=timeout
        )  # type: ignore
        # Dump once and reuse the dict for logging and conversion. This also
        # logs a dict, which is what the async embedding path passes to post_call.
        stringified_response = sync_embedding_response.model_dump()

        ## LOGGING
        logging_obj.model_call_details["response_headers"] = headers
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=stringified_response,
        )
        response: EmbeddingResponse = convert_to_model_response_object(
            response_object=stringified_response,
            model_response_object=model_response,
            _response_headers=headers,
            response_type="embedding",
        )  # type: ignore
        return response
    except OpenAIError:
        raise
    except Exception as e:
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        error_text = getattr(e, "text", str(e))
        error_response = getattr(e, "response", None)
        # Fall back to the headers of the attached HTTP response, if any.
        if error_headers is None and error_response is not None:
            error_headers = getattr(error_response, "headers", None)
        raise OpenAIError(
            status_code=status_code, message=error_text, headers=error_headers
        ) from e
|
|
|
|
async def aimage_generation(
    self,
    prompt: str,
    data: dict,
    model_response: ModelResponse,
    timeout: float,
    logging_obj: Any,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    client=None,
    max_retries=None,
):
    """
    Async image generation through the OpenAI SDK.

    ``data`` is the fully built request body and is splatted into
    ``images.generate`` together with ``timeout``. The SDK response is dumped
    to a dict, passed to ``logging_obj.post_call``, and converted with
    ``response_type="image_generation"``.

    Any exception is logged through ``post_call`` and then re-raised unchanged.
    """
    try:
        async_client = self._get_openai_client(
            is_async=True,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )
        sdk_result = await async_client.images.generate(**data, timeout=timeout)  # type: ignore
        result_as_dict = sdk_result.model_dump()

        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=result_as_dict,
        )
        return convert_to_model_response_object(  # type: ignore
            response_object=result_as_dict,
            model_response_object=model_response,
            response_type="image_generation",
        )
    except Exception as err:
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key=api_key,
            original_response=str(err),
        )
        raise err
|
|
|
|
def image_generation(
    self,
    model: Optional[str],
    prompt: str,
    timeout: float,
    optional_params: dict,
    logging_obj: Any,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    model_response: Optional[ImageResponse] = None,
    client=None,
    aimg_generation=None,
) -> ImageResponse:
    """
    Generate images for ``prompt`` with the OpenAI SDK.

    When ``aimg_generation`` is True this returns the coroutine produced by
    ``self.aimage_generation`` instead of a result.

    ``optional_params`` is merged into the request body. A ``max_retries`` key
    in it is removed from the body and used for the client (default 2).

    Raises:
        OpenAIError: ``status_code=422`` when ``max_retries`` is not an int.
            Other failures are logged and re-raised as ``OpenAIError`` with the
            original ``status_code`` (default 500) and headers when available.
    """
    # Defined before `try` so both exception handlers can log it.
    data = {}
    try:
        data = {"model": model, "prompt": prompt, **optional_params}
        max_retries = data.pop("max_retries", 2)
        if not isinstance(max_retries, int):
            raise OpenAIError(status_code=422, message="max retries must be an int")

        if aimg_generation is True:
            return self.aimage_generation(  # type: ignore
                data=data,
                prompt=prompt,
                logging_obj=logging_obj,
                model_response=model_response,
                api_base=api_base,
                api_key=api_key,
                timeout=timeout,
                client=client,
                max_retries=max_retries,
            )

        openai_client: OpenAI = self._get_openai_client(  # type: ignore
            is_async=False,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )

        ## LOGGING
        # NOTE(review): the Authorization header below embeds the raw API key.
        # Confirm that pre_call masks it before anything is printed or stored.
        logging_obj.pre_call(
            input=prompt,
            api_key=openai_client.api_key,
            additional_args={
                "headers": {"Authorization": f"Bearer {openai_client.api_key}"},
                "api_base": openai_client._base_url._uri_reference,
                # This is the synchronous path; it was previously reported as True.
                "acompletion": False,
                "complete_input_dict": data,
            },
        )

        ## COMPLETION CALL
        _response = openai_client.images.generate(**data, timeout=timeout)  # type: ignore
        response = _response.model_dump()

        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response,
        )
        return convert_to_model_response_object(  # type: ignore
            response_object=response,
            model_response_object=model_response,
            response_type="image_generation",
        )
    except OpenAIError as e:
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=str(e),
        )
        raise e
    except Exception as e:
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=str(e),
        )
        # Preserve response headers (e.g. rate-limit info) as `embedding` does.
        error_headers = getattr(e, "headers", None)
        error_response = getattr(e, "response", None)
        if error_headers is None and error_response is not None:
            error_headers = getattr(error_response, "headers", None)
        raise OpenAIError(
            status_code=getattr(e, "status_code", 500),
            message=str(e),
            headers=error_headers,
        ) from e
|
|
|
|
def audio_speech(
    self,
    model: str,
    input: str,
    voice: str,
    optional_params: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    organization: Optional[str],
    project: Optional[str],
    max_retries: int,
    timeout: Union[float, httpx.Timeout],
    aspeech: Optional[bool] = None,
    client=None,
) -> HttpxBinaryResponseContent:
    """
    Text-to-speech through the OpenAI SDK.

    When ``aspeech`` is True this returns the coroutine from
    ``self.async_audio_speech`` instead of a result. Otherwise it calls
    ``audio.speech.create`` synchronously and wraps the underlying HTTP
    response in ``HttpxBinaryResponseContent``.
    """
    if aspeech is True:
        return self.async_audio_speech(  # type: ignore
            model=model,
            input=input,
            voice=voice,
            optional_params=optional_params,
            api_key=api_key,
            api_base=api_base,
            organization=organization,
            project=project,
            max_retries=max_retries,
            timeout=timeout,
            client=client,
        )

    # NOTE(review): `organization` and `project` are accepted but not forwarded
    # to the client factory here. Confirm whether that is intended.
    sync_client = cast(
        OpenAI,
        self._get_openai_client(
            is_async=False,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        ),
    )
    speech = sync_client.audio.speech.create(
        model=model,
        voice=voice,  # type: ignore
        input=input,
        **optional_params,
    )
    return HttpxBinaryResponseContent(response=speech.response)
|
|
|
|
async def async_audio_speech(
    self,
    model: str,
    input: str,
    voice: str,
    optional_params: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    organization: Optional[str],
    project: Optional[str],
    max_retries: int,
    timeout: Union[float, httpx.Timeout],
    client=None,
) -> HttpxBinaryResponseContent:
    """
    Async text-to-speech through the OpenAI SDK.

    Awaits ``audio.speech.create`` on an async client and wraps the
    underlying HTTP response in ``HttpxBinaryResponseContent``.
    """
    # NOTE(review): `organization` and `project` are accepted but not forwarded
    # to the client factory here. Confirm whether that is intended.
    async_client = self._get_openai_client(
        is_async=True,
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    speech = await cast(AsyncOpenAI, async_client).audio.speech.create(
        model=model,
        voice=voice,  # type: ignore
        input=input,
        **optional_params,
    )
    return HttpxBinaryResponseContent(response=speech.response)
|
|
|
|
async def ahealth_check(
    self,
    model: Optional[str],
    api_key: Optional[str],
    timeout: float,
    mode: str,
    messages: Optional[list] = None,
    input: Optional[list] = None,
    prompt: Optional[str] = None,
    organization: Optional[str] = None,
    api_base: Optional[str] = None,
):
    """
    Issue one raw request for ``mode`` and return selected rate-limit headers.

    Supported modes: ``completion``, ``chat``, ``embedding``,
    ``image_generation``, ``audio_transcription`` (uploads a bundled sample
    WAV file) and ``audio_speech`` (uses ``prompt`` as input, voice "alloy").

    Returns:
        A dict holding ``x-ratelimit-remaining-requests`` and/or
        ``x-ratelimit-remaining-tokens`` when the response carries them
        (they may be absent, e.g. for image generation).

    Raises:
        Exception: when ``model`` is missing (except for image generation),
            when a mode's required argument is missing, or when the
            response has no headers.
        ValueError: for an unknown ``mode``.
    """
    client = AsyncOpenAI(
        api_key=api_key,
        timeout=timeout,
        organization=organization,
        base_url=api_base,
    )
    # The client is created here and used only for this check, so always
    # release its HTTP connection pool (it was previously never closed).
    try:
        if model is None and mode != "image_generation":
            raise Exception("model is not set")

        completion = None

        if mode == "completion":
            completion = await client.completions.with_raw_response.create(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "chat":
            if messages is None:
                raise Exception("messages is not set")
            completion = await client.chat.completions.with_raw_response.create(
                model=model,  # type: ignore
                messages=messages,  # type: ignore
            )
        elif mode == "embedding":
            if input is None:
                raise Exception("input is not set")
            completion = await client.embeddings.with_raw_response.create(
                model=model,  # type: ignore
                input=input,  # type: ignore
            )
        elif mode == "image_generation":
            if prompt is None:
                raise Exception("prompt is not set")
            completion = await client.images.with_raw_response.generate(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_transcription":
            # Sample audio file resolved relative to this source file.
            pwd = os.path.dirname(os.path.realpath(__file__))
            file_path = os.path.join(pwd, "../../../tests/gettysburg.wav")
            # `with` closes the file handle, which was previously leaked.
            with open(file_path, "rb") as audio_file:
                completion = await client.audio.transcriptions.with_raw_response.create(
                    file=audio_file,
                    model=model,  # type: ignore
                    prompt=prompt,  # type: ignore
                )
        elif mode == "audio_speech":
            completion = await client.audio.speech.with_raw_response.create(
                model=model,  # type: ignore
                input=prompt,  # type: ignore
                voice="alloy",
            )
        else:
            raise ValueError("mode not set, passed in mode: " + mode)

        if completion is None or not hasattr(completion, "headers"):
            raise Exception("invalid completion response")

        response = {}
        # Rate-limit headers are not provided for every endpoint (e.g. dall-e).
        for header_name in (
            "x-ratelimit-remaining-requests",
            "x-ratelimit-remaining-tokens",
        ):
            header_value = completion.headers.get(header_name, None)
            if header_value is not None:
                response[header_name] = header_value
        return response
    finally:
        await client.close()
|
|
|
|
|
|
class OpenAIFilesAPI(BaseLLM):
    """
    OpenAI methods to support the Files API

    - create_file()
    - retrieve_file()
    - list_files()
    - delete_file()
    - file_content()

    Every public method takes ``_is_async``. When it is True, the method
    returns the coroutine of the matching ``a*`` method, which requires an
    ``AsyncOpenAI`` client. Otherwise the SDK is called synchronously.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
        _is_async: bool = False,
    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
        """
        Return ``client`` if given; otherwise build a new SDK client.

        Only arguments that are not ``None`` are forwarded to the constructor,
        and ``api_base`` is passed as ``base_url``. ``AsyncOpenAI`` is built
        when ``_is_async`` is True, ``OpenAI`` otherwise.
        """
        if client is not None:
            return client
        candidate_kwargs = {
            "api_key": api_key,
            "base_url": api_base,
            "timeout": timeout,
            "max_retries": max_retries,
            "organization": organization,
        }
        data = {k: v for k, v in candidate_kwargs.items() if v is not None}
        if _is_async is True:
            return AsyncOpenAI(**data)
        return OpenAI(**data)  # type: ignore

    def _get_checked_client(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]],
    ) -> Union[OpenAI, AsyncOpenAI]:
        """Resolve a client via ``get_openai_client``; raise ValueError if it is None."""
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )
        return openai_client

    @staticmethod
    def _require_async_client(
        openai_client: Union[OpenAI, AsyncOpenAI],
    ) -> AsyncOpenAI:
        """Raise ValueError unless ``openai_client`` is an ``AsyncOpenAI``."""
        if not isinstance(openai_client, AsyncOpenAI):
            raise ValueError(
                "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
            )
        return openai_client

    async def acreate_file(
        self,
        create_file_data: CreateFileRequest,
        openai_client: AsyncOpenAI,
    ) -> FileObject:
        response = await openai_client.files.create(**create_file_data)
        return response

    def create_file(
        self,
        _is_async: bool,
        create_file_data: CreateFileRequest,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
        """Upload a file; returns a coroutine when ``_is_async`` is True."""
        openai_client = self._get_checked_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.acreate_file(  # type: ignore
                create_file_data=create_file_data,
                openai_client=self._require_async_client(openai_client),
            )
        return openai_client.files.create(**create_file_data)

    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        openai_client: AsyncOpenAI,
    ) -> HttpxBinaryResponseContent:
        response = await openai_client.files.content(**file_content_request)
        return HttpxBinaryResponseContent(response=response.response)

    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        """Fetch a file's content wrapped in ``HttpxBinaryResponseContent``."""
        openai_client = self._get_checked_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.afile_content(  # type: ignore
                file_content_request=file_content_request,
                openai_client=self._require_async_client(openai_client),
            )
        response = cast(OpenAI, openai_client).files.content(**file_content_request)
        return HttpxBinaryResponseContent(response=response.response)

    async def aretrieve_file(
        self,
        file_id: str,
        openai_client: AsyncOpenAI,
    ) -> FileObject:
        response = await openai_client.files.retrieve(file_id=file_id)
        return response

    def retrieve_file(
        self,
        _is_async: bool,
        file_id: str,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        """Retrieve a file's metadata by id."""
        openai_client = self._get_checked_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.aretrieve_file(  # type: ignore
                file_id=file_id,
                openai_client=self._require_async_client(openai_client),
            )
        return openai_client.files.retrieve(file_id=file_id)

    async def adelete_file(
        self,
        file_id: str,
        openai_client: AsyncOpenAI,
    ) -> FileDeleted:
        response = await openai_client.files.delete(file_id=file_id)
        return response

    def delete_file(
        self,
        _is_async: bool,
        file_id: str,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        """Delete a file by id."""
        openai_client = self._get_checked_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.adelete_file(  # type: ignore
                file_id=file_id,
                openai_client=self._require_async_client(openai_client),
            )
        return openai_client.files.delete(file_id=file_id)

    async def alist_files(
        self,
        openai_client: AsyncOpenAI,
        purpose: Optional[str] = None,
    ):
        if isinstance(purpose, str):
            response = await openai_client.files.list(purpose=purpose)
        else:
            response = await openai_client.files.list()
        return response

    def list_files(
        self,
        _is_async: bool,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        purpose: Optional[str] = None,
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        """List files, filtered by ``purpose`` when it is a string."""
        openai_client = self._get_checked_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.alist_files(  # type: ignore
                purpose=purpose,
                openai_client=self._require_async_client(openai_client),
            )
        if isinstance(purpose, str):
            return openai_client.files.list(purpose=purpose)
        return openai_client.files.list()
|
|
|
|
|
|
class OpenAIBatchesAPI(BaseLLM):
    """
    OpenAI methods to support for batches

    - create_batch()
    - retrieve_batch()
    - cancel_batch()
    - list_batches()

    When ``_is_async`` is True, ``create_batch``/``retrieve_batch``/
    ``list_batches`` return the coroutine of the matching ``a*`` method and
    require an ``AsyncOpenAI`` client. ``cancel_batch`` dispatches on the
    client type instead (see its docstring).
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
        _is_async: bool = False,
    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
        """
        Return ``client`` if given; otherwise build a new SDK client.

        Only arguments that are not ``None`` are forwarded to the constructor,
        and ``api_base`` is passed as ``base_url``. ``AsyncOpenAI`` is built
        when ``_is_async`` is True, ``OpenAI`` otherwise.
        """
        if client is not None:
            return client
        candidate_kwargs = {
            "api_key": api_key,
            "base_url": api_base,
            "timeout": timeout,
            "max_retries": max_retries,
            "organization": organization,
        }
        data = {k: v for k, v in candidate_kwargs.items() if v is not None}
        if _is_async is True:
            return AsyncOpenAI(**data)
        return OpenAI(**data)  # type: ignore

    def _get_checked_client(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]],
    ) -> Union[OpenAI, AsyncOpenAI]:
        """Resolve a client via ``get_openai_client``; raise ValueError if it is None."""
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )
        return openai_client

    @staticmethod
    def _require_async_client(
        openai_client: Union[OpenAI, AsyncOpenAI],
    ) -> AsyncOpenAI:
        """Raise ValueError unless ``openai_client`` is an ``AsyncOpenAI``."""
        if not isinstance(openai_client, AsyncOpenAI):
            raise ValueError(
                "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
            )
        return openai_client

    async def acreate_batch(
        self,
        create_batch_data: CreateBatchRequest,
        openai_client: AsyncOpenAI,
    ) -> Batch:
        response = await openai_client.batches.create(**create_batch_data)
        return response

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
        """Create a batch; returns a coroutine when ``_is_async`` is True."""
        openai_client = self._get_checked_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.acreate_batch(  # type: ignore
                create_batch_data=create_batch_data,
                openai_client=self._require_async_client(openai_client),
            )
        return openai_client.batches.create(**create_batch_data)

    async def aretrieve_batch(
        self,
        retrieve_batch_data: RetrieveBatchRequest,
        openai_client: AsyncOpenAI,
    ) -> Batch:
        verbose_logger.debug("retrieving batch, args= %s", retrieve_batch_data)
        response = await openai_client.batches.retrieve(**retrieve_batch_data)
        return response

    def retrieve_batch(
        self,
        _is_async: bool,
        retrieve_batch_data: RetrieveBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        """Retrieve a batch; returns a coroutine when ``_is_async`` is True."""
        openai_client = self._get_checked_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.aretrieve_batch(  # type: ignore
                retrieve_batch_data=retrieve_batch_data,
                openai_client=self._require_async_client(openai_client),
            )
        return openai_client.batches.retrieve(**retrieve_batch_data)

    async def acancel_batch(
        self,
        cancel_batch_data: CancelBatchRequest,
        openai_client: AsyncOpenAI,
    ) -> Batch:
        verbose_logger.debug("cancelling batch, args= %s", cancel_batch_data)
        response = await openai_client.batches.cancel(**cancel_batch_data)
        return response

    def cancel_batch(
        self,
        _is_async: bool,
        cancel_batch_data: CancelBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        """
        Cancel a batch.

        Unlike the other methods, this one has never rejected a sync client
        when ``_is_async`` is True. It therefore dispatches on the resolved
        client's type: an ``AsyncOpenAI`` client yields the ``acancel_batch``
        coroutine, and a sync client is called directly.
        """
        openai_client = self._get_checked_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if isinstance(openai_client, AsyncOpenAI):
            return self.acancel_batch(
                cancel_batch_data=cancel_batch_data, openai_client=openai_client
            )
        return openai_client.batches.cancel(**cancel_batch_data)

    async def alist_batches(
        self,
        openai_client: AsyncOpenAI,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        verbose_logger.debug("listing batches, after= %s, limit= %s", after, limit)
        response = await openai_client.batches.list(after=after, limit=limit)  # type: ignore
        return response

    def list_batches(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        after: Optional[str] = None,
        limit: Optional[int] = None,
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        """List batches (paginated by ``after``/``limit``)."""
        openai_client = self._get_checked_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.alist_batches(  # type: ignore
                openai_client=self._require_async_client(openai_client),
                after=after,
                limit=limit,
            )
        return openai_client.batches.list(after=after, limit=limit)  # type: ignore
|
|
|
|
|
|
class OpenAIAssistantsAPI(BaseLLM):
|
|
def __init__(self) -> None:
    """Initialise the assistants helper by delegating to ``BaseLLM``."""
    super(OpenAIAssistantsAPI, self).__init__()
|
|
|
|
def get_openai_client(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[OpenAI] = None,
) -> OpenAI:
    """
    Return ``client`` unchanged if supplied; otherwise build a sync ``OpenAI``.

    Only settings that are not ``None`` are passed to the constructor, and
    ``api_base`` is passed as ``base_url``.
    """
    if client is not None:
        return client
    settings = (
        ("api_key", api_key),
        ("base_url", api_base),
        ("timeout", timeout),
        ("max_retries", max_retries),
        ("organization", organization),
    )
    constructor_kwargs = {name: value for name, value in settings if value is not None}
    return OpenAI(**constructor_kwargs)  # type: ignore
|
|
|
|
def async_get_openai_client(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI] = None,
) -> AsyncOpenAI:
    """
    Return ``client`` unchanged if supplied; otherwise build an ``AsyncOpenAI``.

    Only settings that are not ``None`` are passed to the constructor, and
    ``api_base`` is passed as ``base_url``.
    """
    if client is not None:
        return client
    settings = (
        ("api_key", api_key),
        ("base_url", api_base),
        ("timeout", timeout),
        ("max_retries", max_retries),
        ("organization", organization),
    )
    constructor_kwargs = {name: value for name, value in settings if value is not None}
    return AsyncOpenAI(**constructor_kwargs)  # type: ignore
|
|
|
|
### ASSISTANTS ###
|
|
|
|
async def async_get_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
) -> AsyncCursorPage[Assistant]:
    """List assistants using an async client from ``async_get_openai_client``."""
    async_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )
    return await async_client.beta.assistants.list()
|
|
|
|
# fmt: off
|
|
|
|
@overload
def get_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
    aget_assistants: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
    ...

@overload
def get_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[OpenAI],
    aget_assistants: Optional[Literal[False]],
) -> SyncCursorPage[Assistant]:
    ...

# fmt: on

def get_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client=None,
    aget_assistants=None,
):
    """
    List assistants.

    When ``aget_assistants`` is True, returns the coroutine from
    ``async_get_assistants``. Otherwise lists synchronously with a client from
    ``get_openai_client``.
    """
    connection_kwargs = {
        "api_key": api_key,
        "api_base": api_base,
        "timeout": timeout,
        "max_retries": max_retries,
        "organization": organization,
        "client": client,
    }
    if aget_assistants is True:
        return self.async_get_assistants(**connection_kwargs)
    return self.get_openai_client(**connection_kwargs).beta.assistants.list()
|
|
|
|
# Create Assistant
|
|
async def async_create_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
    create_assistant_data: dict,
) -> Assistant:
    """Create an assistant from ``create_assistant_data`` with an async client."""
    async_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )
    return await async_client.beta.assistants.create(**create_assistant_data)
|
|
|
|
def create_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    create_assistant_data: dict,
    client=None,
    async_create_assistants=None,
):
    """
    Create an assistant.

    When ``async_create_assistants`` is True, returns the coroutine from
    ``self.async_create_assistants``. Otherwise creates it synchronously.
    """
    connection_kwargs = {
        "api_key": api_key,
        "api_base": api_base,
        "timeout": timeout,
        "max_retries": max_retries,
        "organization": organization,
        "client": client,
    }
    if async_create_assistants is True:
        return self.async_create_assistants(
            create_assistant_data=create_assistant_data, **connection_kwargs
        )
    sync_client = self.get_openai_client(**connection_kwargs)
    return sync_client.beta.assistants.create(**create_assistant_data)
|
|
|
|
# Delete Assistant
|
|
async def async_delete_assistant(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
    assistant_id: str,
) -> AssistantDeleted:
    """Delete the assistant ``assistant_id`` with an async client."""
    async_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )
    return await async_client.beta.assistants.delete(assistant_id=assistant_id)
|
|
|
|
def delete_assistant(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    assistant_id: str,
    client=None,
    async_delete_assistants=None,
):
    """
    Delete an assistant by id.

    When ``async_delete_assistants`` is True, returns the coroutine from
    ``self.async_delete_assistant``. Otherwise deletes it synchronously.
    """
    connection_kwargs = {
        "api_key": api_key,
        "api_base": api_base,
        "timeout": timeout,
        "max_retries": max_retries,
        "organization": organization,
        "client": client,
    }
    if async_delete_assistants is True:
        return self.async_delete_assistant(
            assistant_id=assistant_id, **connection_kwargs
        )
    sync_client = self.get_openai_client(**connection_kwargs)
    return sync_client.beta.assistants.delete(assistant_id=assistant_id)
|
|
|
|
### MESSAGES ###
|
|
|
|
async def a_add_message(
    self,
    thread_id: str,
    message_data: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI] = None,
) -> OpenAIMessage:
    """Create a message on ``thread_id`` using an async OpenAI client.

    Args:
        thread_id: Thread to append the message to.
        message_data: Forwarded as keyword arguments to
            ``beta.threads.messages.create``.
        api_key, api_base, timeout, max_retries, organization, client:
            Passed to ``async_get_openai_client`` to resolve the client.

    Returns:
        The created message re-wrapped as ``OpenAIMessage``. If the SDK
        object's ``status`` is missing or ``None``, it is set to
        ``"completed"`` before conversion.
    """
    openai_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )

    thread_message = await openai_client.beta.threads.messages.create(  # type: ignore
        thread_id, **message_data  # type: ignore
    )

    # NOTE(review): presumably some OpenAI-compatible backends omit `status`;
    # default it on the SDK object before dumping so the field is populated.
    if getattr(thread_message, "status", None) is None:
        thread_message.status = "completed"
    return OpenAIMessage(**thread_message.dict())
|
|
|
|
# fmt: off

@overload
def add_message(
    self, thread_id: str, message_data: dict,
    api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int], organization: Optional[str],
    client: Optional[AsyncOpenAI], a_add_message: Literal[True],
) -> Coroutine[None, None, OpenAIMessage]:
    """Typing overload for ``a_add_message=True``: a coroutine yielding the message."""

@overload
def add_message(
    self, thread_id: str, message_data: dict,
    api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int], organization: Optional[str],
    client: Optional[OpenAI], a_add_message: Optional[Literal[False]],
) -> OpenAIMessage:
    """Typing overload for the synchronous path: the message itself."""

# fmt: on
|
|
|
|
def add_message(
    self,
    thread_id: str,
    message_data: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client=None,
    a_add_message: Optional[bool] = None,
):
    """Create a message on ``thread_id``.

    Args:
        thread_id: Thread to append the message to.
        message_data: Forwarded as keyword arguments to
            ``beta.threads.messages.create``.
        api_key, api_base, timeout, max_retries, organization, client:
            Used to resolve the OpenAI client.
        a_add_message: When exactly ``True``, delegate to ``a_add_message``
            and return its un-awaited coroutine.

    Returns:
        On the sync path, the created message re-wrapped as
        ``OpenAIMessage``. A missing or ``None`` ``status`` is defaulted to
        ``"completed"`` first.
    """
    if a_add_message is True:
        return self.a_add_message(
            thread_id=thread_id,
            message_data=message_data,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

    openai_client = self.get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )

    thread_message = openai_client.beta.threads.messages.create(  # type: ignore
        thread_id, **message_data  # type: ignore
    )

    # NOTE(review): presumably some OpenAI-compatible backends omit `status`;
    # default it on the SDK object before dumping so the field is populated.
    if getattr(thread_message, "status", None) is None:
        thread_message.status = "completed"
    return OpenAIMessage(**thread_message.dict())
|
|
|
|
async def async_get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI] = None,
) -> AsyncCursorPage[OpenAIMessage]:
    """List the messages of ``thread_id`` through an async OpenAI client.

    The SDK's cursor page is returned as-is.
    """
    async_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )
    return await async_client.beta.threads.messages.list(thread_id=thread_id)
|
|
|
|
# fmt: off

@overload
def get_messages(
    self, thread_id: str,
    api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int], organization: Optional[str],
    client: Optional[AsyncOpenAI], aget_messages: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
    """Typing overload for ``aget_messages=True``: a coroutine yielding an async page."""

@overload
def get_messages(
    self, thread_id: str,
    api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int], organization: Optional[str],
    client: Optional[OpenAI], aget_messages: Optional[Literal[False]],
) -> SyncCursorPage[OpenAIMessage]:
    """Typing overload for the synchronous path: a sync cursor page."""

# fmt: on
|
|
|
|
def get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client=None,
    aget_messages=None,
):
    """List the messages of ``thread_id``.

    With ``aget_messages`` exactly ``True``, the un-awaited coroutine from
    ``async_get_messages`` is returned. Otherwise the SDK's ``list`` result
    from a sync client is returned.
    """
    client_options = dict(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )
    if aget_messages is True:
        return self.async_get_messages(thread_id=thread_id, **client_options)

    sync_client = self.get_openai_client(**client_options)
    return sync_client.beta.threads.messages.list(thread_id=thread_id)
|
|
|
|
### THREADS ###
|
|
|
|
async def async_create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
    """Create a thread through an async OpenAI client.

    ``messages`` and ``metadata`` are forwarded to ``beta.threads.create``
    only when they are not ``None``. The SDK result is re-wrapped as
    ``Thread``.
    """
    async_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )

    create_kwargs = {}
    for field_name, field_value in (("messages", messages), ("metadata", metadata)):
        if field_value is not None:
            create_kwargs[field_name] = field_value

    created = await async_client.beta.threads.create(**create_kwargs)  # type: ignore
    return Thread(**created.dict())
|
|
|
|
# fmt: off

@overload
def create_thread(
    self, metadata: Optional[dict],
    api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int], organization: Optional[str],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
    client: Optional[AsyncOpenAI], acreate_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
    """Typing overload for ``acreate_thread=True``: a coroutine yielding a ``Thread``."""

@overload
def create_thread(
    self, metadata: Optional[dict],
    api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int], organization: Optional[str],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
    client: Optional[OpenAI], acreate_thread: Optional[Literal[False]],
) -> Thread:
    """Typing overload for the synchronous path: the ``Thread`` itself."""

# fmt: on
|
|
|
|
def create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
    client=None,
    acreate_thread=None,
):
    """Create a new assistants thread.

    ``messages`` and ``metadata`` are sent to ``beta.threads.create`` only
    when they are not ``None``; the SDK result is re-wrapped as ``Thread``.
    With ``acreate_thread`` exactly ``True``, the un-awaited coroutine from
    ``async_create_thread`` is returned instead.

    Example (``openai_api`` being an instance of this class)::

        message = {"role": "user", "content": "Hey, how's it going?"}
        openai_api.create_thread(
            metadata=None,
            api_key=api_key,
            api_base=None,
            timeout=timeout,
            max_retries=None,
            organization=None,
            messages=[message],
        )
    """
    if acreate_thread is True:
        return self.async_create_thread(
            metadata=metadata,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            messages=messages,
        )

    sync_client = self.get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )

    thread_kwargs = {
        key: value
        for key, value in (("messages", messages), ("metadata", metadata))
        if value is not None
    }
    created = sync_client.beta.threads.create(**thread_kwargs)  # type: ignore
    return Thread(**created.dict())
|
|
|
|
async def async_get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
) -> Thread:
    """Retrieve thread ``thread_id`` asynchronously and re-wrap it as ``Thread``."""
    async_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )
    retrieved = await async_client.beta.threads.retrieve(thread_id=thread_id)
    return Thread(**retrieved.dict())
|
|
|
|
# fmt: off

@overload
def get_thread(
    self, thread_id: str,
    api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int], organization: Optional[str],
    client: Optional[AsyncOpenAI], aget_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
    """Typing overload for ``aget_thread=True``: a coroutine yielding a ``Thread``."""

@overload
def get_thread(
    self, thread_id: str,
    api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int], organization: Optional[str],
    client: Optional[OpenAI], aget_thread: Optional[Literal[False]],
) -> Thread:
    """Typing overload for the synchronous path: the ``Thread`` itself."""

# fmt: on
|
|
|
|
def get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client=None,
    aget_thread=None,
):
    """Retrieve thread ``thread_id`` and re-wrap it as ``Thread``.

    With ``aget_thread`` exactly ``True``, the un-awaited coroutine from
    ``async_get_thread`` is returned instead.
    """
    client_options = dict(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )
    if aget_thread is True:
        return self.async_get_thread(thread_id=thread_id, **client_options)

    sync_client = self.get_openai_client(**client_options)
    retrieved = sync_client.beta.threads.retrieve(thread_id=thread_id)
    return Thread(**retrieved.dict())
|
|
|
|
def delete_thread(self):
    """Placeholder for thread deletion; not implemented here.

    Takes no arguments, performs no request, and returns ``None``.
    """
    return None
|
|
|
|
### RUNS ###
|
|
|
|
async def arun_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
) -> Run:
    """Start a run of ``assistant_id`` on ``thread_id`` via the async SDK.

    The run is issued with ``beta.threads.runs.create_and_poll`` and its
    result is returned. ``stream`` is accepted for signature parity but is
    not used by this method.
    """
    async_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )
    run_params = dict(
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )
    return await async_client.beta.threads.runs.create_and_poll(**run_params)  # type: ignore
|
|
|
|
def async_run_thread_stream(
    self,
    client: AsyncOpenAI,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    tools: Optional[Iterable[AssistantToolParam]],
    event_handler: Optional[AssistantEventHandler],
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
    """Open a streaming run on an already-resolved async client.

    This is a plain (non-async) method: it returns the object produced by
    ``client.beta.threads.runs.stream``. ``event_handler`` is forwarded only
    when it is not ``None``.
    """
    handler_kwargs = {} if event_handler is None else {"event_handler": event_handler}
    return client.beta.threads.runs.stream(  # type: ignore
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
        **handler_kwargs,
    )
|
|
|
|
def run_thread_stream(
    self,
    client: OpenAI,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    tools: Optional[Iterable[AssistantToolParam]],
    event_handler: Optional[AssistantEventHandler],
) -> AssistantStreamManager[AssistantEventHandler]:
    """Open a streaming run on an already-resolved sync client.

    Returns the object produced by ``client.beta.threads.runs.stream``.
    ``event_handler`` is forwarded only when it is not ``None``.
    """
    handler_kwargs = {} if event_handler is None else {"event_handler": event_handler}
    return client.beta.threads.runs.stream(  # type: ignore
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
        **handler_kwargs,
    )
|
|
|
|
# fmt: off

# NOTE(review): with stream=True the implementation returns a stream manager
# (from async_run_thread_stream) rather than the coroutine declared here.
@overload
def run_thread(
    self, thread_id: str, assistant_id: str,
    additional_instructions: Optional[str], instructions: Optional[str],
    metadata: Optional[object], model: Optional[str], stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int], organization: Optional[str], client,
    arun_thread: Literal[True], event_handler: Optional[AssistantEventHandler],
) -> Coroutine[None, None, Run]:
    """Typing overload for ``arun_thread=True``."""

@overload
def run_thread(
    self, thread_id: str, assistant_id: str,
    additional_instructions: Optional[str], instructions: Optional[str],
    metadata: Optional[object], model: Optional[str], stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int], organization: Optional[str], client,
    arun_thread: Optional[Literal[False]], event_handler: Optional[AssistantEventHandler],
) -> Run:
    """Typing overload for the synchronous path."""

# fmt: on
|
|
|
|
def run_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client=None,
    arun_thread=None,
    event_handler: Optional[AssistantEventHandler] = None,
):
    """Run assistant ``assistant_id`` on thread ``thread_id``.

    Dispatch depends on two flags, each enabled only by the literal ``True``:

    * ``arun_thread`` and ``stream``: resolve an async client and return the
      stream manager from ``async_run_thread_stream`` (not a coroutine).
    * ``arun_thread`` only: return the un-awaited ``arun_thread`` coroutine.
    * ``stream`` only: resolve a sync client and return the stream manager
      from ``run_thread_stream``.
    * neither: call ``beta.threads.runs.create_and_poll`` synchronously and
      return its result.

    ``event_handler`` is forwarded only on the two streaming paths.
    """
    connection = dict(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )
    wants_stream = stream is True

    if arun_thread is True:
        if wants_stream:
            return self.async_run_thread_stream(
                client=self.async_get_openai_client(**connection),
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                model=model,
                tools=tools,
                event_handler=event_handler,
            )
        return self.arun_thread(
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            stream=stream,
            tools=tools,
            **connection,
        )

    sync_client = self.get_openai_client(**connection)

    if wants_stream:
        return self.run_thread_stream(
            client=sync_client,
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            tools=tools,
            event_handler=event_handler,
        )

    return sync_client.beta.threads.runs.create_and_poll(  # type: ignore
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )
|