litellm/litellm/llms/openai.py
2024-08-27 12:14:23 -07:00

3566 lines
124 KiB
Python

import hashlib
import json
import os
import time
import traceback
import types
from typing import (
Any,
BinaryIO,
Callable,
Coroutine,
Iterable,
Literal,
Optional,
Union,
)
import httpx
import openai
from openai import AsyncOpenAI, OpenAI
from openai.types.beta.assistant_deleted import AssistantDeleted
from openai.types.file_deleted import FileDeleted
from pydantic import BaseModel
from typing_extensions import overload, override
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import ProviderField
from litellm.utils import (
Choices,
CustomStreamWrapper,
Message,
ModelResponse,
TextCompletionResponse,
TranscriptionResponse,
Usage,
convert_to_model_response_object,
)
from ..types.llms.openai import *
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class OpenAIError(Exception):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
headers: Optional[httpx.Headers] = None,
):
self.status_code = status_code
self.message = message
self.headers = headers
if request:
self.request = request
else:
self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
if response:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class MistralConfig:
"""
Reference: https://docs.mistral.ai/api/
The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
- `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
"""
temperature: Optional[int] = None
top_p: Optional[int] = None
max_tokens: Optional[int] = None
tools: Optional[list] = None
tool_choice: Optional[Literal["auto", "any", "none"]] = None
random_seed: Optional[int] = None
safe_prompt: Optional[bool] = None
response_format: Optional[dict] = None
stop: Optional[Union[str, list]] = None
def __init__(
self,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
max_tokens: Optional[int] = None,
tools: Optional[list] = None,
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
stop: Optional[Union[str, list]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
"seed",
"stop",
"response_format",
]
def _map_tool_choice(self, tool_choice: str) -> str:
if tool_choice == "auto" or tool_choice == "none":
return tool_choice
elif tool_choice == "required":
return "any"
else: # openai 'tool_choice' object param not supported by Mistral API
return "any"
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
if param == "response_format":
optional_params["response_format"] = value
return optional_params
class MistralEmbeddingConfig:
"""
Reference: https://docs.mistral.ai/api/#operation/createEmbedding
"""
def __init__(
self,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"encoding_format",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "encoding_format":
optional_params["encoding_format"] = value
return optional_params
class AzureAIStudioConfig:
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
ProviderField(
field_name="api_key",
field_type="string",
field_description="Your Azure AI Studio API Key.",
field_value="zEJ...",
),
ProviderField(
field_name="api_base",
field_type="string",
field_description="Your Azure AI Studio API Base.",
field_value="https://Mistral-serverless.",
),
]
class DeepInfraConfig:
"""
Reference: https://deepinfra.com/docs/advanced/openai_api
The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
tools: Optional[list] = None
tool_choice: Optional[Union[str, dict]] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
tools: Optional[list] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"frequency_penalty",
"function_call",
"functions",
"logit_bias",
"max_tokens",
"n",
"presence_penalty",
"stop",
"temperature",
"top_p",
"response_format",
"tools",
"tool_choice",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params()
for param, value in non_default_params.items():
if (
param == "temperature"
and value == 0
and model == "mistralai/Mistral-7B-Instruct-v0.1"
): # this model does no support temperature == 0
value = 0.0001 # close to 0
if param == "tool_choice":
if (
value != "auto" and value != "none"
): # https://deepinfra.com/docs/advanced/function_calling
## UNSUPPORTED TOOL CHOICE VALUE
if litellm.drop_params is True or drop_params is True:
value = None
else:
raise litellm.utils.UnsupportedParamsError(
message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
value
),
status_code=400,
)
if param in supported_openai_params:
if value is not None:
optional_params[param] = value
return optional_params
class GroqConfig:
"""
Reference: https://deepinfra.com/docs/advanced/openai_api
The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
tools: Optional[list] = None
tool_choice: Optional[Union[str, dict]] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
tools: Optional[list] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params_stt(self):
return [
"prompt",
"response_format",
"temperature",
"language",
]
def get_supported_openai_response_formats_stt(self) -> List[str]:
return ["json", "verbose_json", "text"]
def map_openai_params_stt(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
response_formats = self.get_supported_openai_response_formats_stt()
for param, value in non_default_params.items():
if param == "response_format":
if value in response_formats:
optional_params[param] = value
else:
if litellm.drop_params is True or drop_params is True:
pass
else:
raise litellm.utils.UnsupportedParamsError(
message="Groq doesn't support response_format={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
value
),
status_code=400,
)
else:
optional_params[param] = value
return optional_params
class OpenAIConfig:
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
The class `OpenAIConfig` provides configuration for the OpenAI's Chat API interface. Below are the parameters:
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
- `function_call` (string or object): This optional parameter controls how the model calls functions.
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> list:
base_params = [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"function_call",
"functions",
"max_retries",
"extra_headers",
"parallel_tool_calls",
] # works across all models
model_specific_params = []
if (
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
): # gpt-4 does not support 'response_format'
model_specific_params.append("response_format")
if (
model in litellm.open_ai_chat_completion_models
) or model in litellm.open_ai_text_completion_models:
model_specific_params.append(
"user"
) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
return base_params + model_specific_params
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
) -> dict:
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
class OpenAITextCompletionConfig:
"""
Reference: https://platform.openai.com/docs/api-reference/completions/create
The class `OpenAITextCompletionConfig` provides configuration for the OpenAI's text completion API interface. Below are the parameters:
- `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.
- `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.
- `frequency_penalty` (number or null): Defaults to 0. It is a numbers from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens.
- `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.
- `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.
- `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.
- `temperature` (number or null): This optional parameter defines the sampling temperature to use.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
best_of: Optional[int] = None
echo: Optional[bool] = None
frequency_penalty: Optional[int] = None
logit_bias: Optional[dict] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
suffix: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
def __init__(
self,
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[int] = None,
logit_bias: Optional[dict] = None,
logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
suffix: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def convert_to_chat_model_response_object(
self,
response_object: Optional[TextCompletionResponse] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["text"],
role="assistant",
)
choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
)
choice_list.append(choice)
model_response_object.choices = choice_list
if "usage" in response_object:
setattr(model_response_object, "usage", response_object["usage"])
if "id" in response_object:
model_response_object.id = response_object["id"]
if "model" in response_object:
model_response_object.model = response_object["model"]
model_response_object._hidden_params["original_response"] = (
response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
)
return model_response_object
except Exception as e:
raise e
class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _get_openai_client(
self,
is_async: bool,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
max_retries: Optional[int] = 2,
organization: Optional[str] = None,
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
args = locals()
if client is None:
if not isinstance(max_retries, int):
raise OpenAIError(
status_code=422,
message="max retries must be an int. Passed in value: {}".format(
max_retries
),
)
# Creating a new OpenAI Client
# check in memory cache before creating a new one
# Convert the API key to bytes
hashed_api_key = None
if api_key is not None:
hash_object = hashlib.sha256(api_key.encode())
# Hexadecimal representation of the hash
hashed_api_key = hash_object.hexdigest()
_cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"
if _cache_key in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key]
if is_async:
_new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
_new_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
## SAVE CACHE KEY
litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
return _new_client
else:
return client
async def make_openai_chat_completion_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
) -> Tuple[dict, BaseModel]:
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default
"""
try:
raw_response = (
await openai_aclient.chat.completions.with_raw_response.create(
**data, timeout=timeout
)
)
if hasattr(raw_response, "headers"):
headers = dict(raw_response.headers)
else:
headers = {}
response = raw_response.parse()
return headers, response
except OpenAIError as e:
raise e
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
def make_sync_openai_chat_completion_request(
self,
openai_client: OpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
) -> Tuple[dict, BaseModel]:
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default
"""
try:
raw_response = openai_client.chat.completions.with_raw_response.create(
**data, timeout=timeout
)
if hasattr(raw_response, "headers"):
headers = dict(raw_response.headers)
else:
headers = {}
response = raw_response.parse()
return headers, response
except OpenAIError as e:
raise e
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
def completion(
self,
model_response: ModelResponse,
timeout: Union[float, httpx.Timeout],
optional_params: dict,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
logging_obj=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
super().completion()
exception_mapping_worked = False
try:
if headers:
optional_params["extra_headers"] = headers
if model is None or messages is None:
raise OpenAIError(status_code=422, message="Missing model or messages")
if not isinstance(timeout, float) and not isinstance(
timeout, httpx.Timeout
):
raise OpenAIError(
status_code=422,
message="Timeout needs to be a float or httpx.Timeout",
)
if custom_llm_provider is not None and custom_llm_provider != "openai":
model_response.model = f"{custom_llm_provider}/{model}"
# process all OpenAI compatible provider logic here
if custom_llm_provider == "mistral":
# check if message content passed in as list, and not string
messages = prompt_factory(
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
if custom_llm_provider == "perplexity" and messages is not None:
# check if messages.name is passed + supported, if not supported remove
messages = prompt_factory(
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
for _ in range(
2
): # if call fails due to alternating messages, retry with reformatted message
data = {"model": model, "messages": messages, **optional_params}
try:
max_retries = data.pop("max_retries", 2)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(
logging_obj=logging_obj,
headers=headers,
data=data,
model=model,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
drop_params=drop_params,
)
else:
return self.acompletion(
data=data,
headers=headers,
logging_obj=logging_obj,
model_response=model_response,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
drop_params=drop_params,
)
elif optional_params.get("stream", False):
return self.streaming(
logging_obj=logging_obj,
headers=headers,
data=data,
model=model,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
else:
if not isinstance(max_retries, int):
raise OpenAIError(
status_code=422, message="max retries must be an int"
)
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=openai_client.api_key,
additional_args={
"headers": headers,
"api_base": openai_client._base_url._uri_reference,
"acompletion": acompletion,
"complete_input_dict": data,
},
)
headers, response = (
self.make_sync_openai_chat_completion_request(
openai_client=openai_client,
data=data,
timeout=timeout,
)
)
logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
_response_headers=headers,
)
except openai.UnprocessableEntityError as e:
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
if litellm.drop_params is True or drop_params is True:
invalid_params: List[str] = []
if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore
detail = e.body.get("detail") # type: ignore
if (
isinstance(detail, List)
and len(detail) > 0
and isinstance(detail[0], dict)
):
for error_dict in detail:
if (
error_dict.get("loc")
and isinstance(error_dict.get("loc"), list)
and len(error_dict.get("loc")) == 2
):
invalid_params.append(error_dict["loc"][1])
new_data = {}
for k, v in optional_params.items():
if k not in invalid_params:
new_data[k] = v
optional_params = new_data
else:
raise e
# e.message
except Exception as e:
if print_verbose is not None:
print_verbose(f"openai.py: Received openai error - {str(e)}")
if (
"Conversation roles must alternate user/assistant" in str(e)
or "user and assistant roles should be alternating" in str(e)
) and messages is not None:
if print_verbose is not None:
print_verbose("openai.py: REFORMATS THE MESSAGE!")
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
new_messages = []
for i in range(len(messages) - 1): # type: ignore
new_messages.append(messages[i])
if messages[i]["role"] == messages[i + 1]["role"]:
if messages[i]["role"] == "user":
new_messages.append(
{"role": "assistant", "content": ""}
)
else:
new_messages.append({"role": "user", "content": ""})
new_messages.append(messages[-1])
messages = new_messages
elif (
"Last message must have role `user`" in str(e)
) and messages is not None:
new_messages = messages
new_messages.append({"role": "user", "content": ""})
messages = new_messages
elif (
"unknown field: parameter index is not a valid field" in str(e)
) and "tools" in data:
litellm.remove_index_from_tool_calls(
tool_calls=data["tools"], messages=messages
)
else:
raise e
except OpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if hasattr(e, "status_code"):
raise OpenAIError(status_code=e.status_code, message=str(e))
else:
raise OpenAIError(status_code=500, message=traceback.format_exc())
async def acompletion(
self,
data: dict,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
timeout: Union[float, httpx.Timeout],
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
drop_params: Optional[bool] = None,
):
response = None
for _ in range(
2
): # if call fails due to alternating messages, retry with reformatted message
try:
openai_aclient = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=data["messages"],
api_key=openai_aclient.api_key,
additional_args={
"headers": {
"Authorization": f"Bearer {openai_aclient.api_key}"
},
"api_base": openai_aclient._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
headers, response = await self.make_openai_chat_completion_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
)
stringified_response = response.model_dump()
logging_obj.post_call(
input=data["messages"],
api_key=api_key,
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
logging_obj.model_call_details["response_headers"] = headers
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
hidden_params={"headers": headers},
_response_headers=headers,
)
except openai.UnprocessableEntityError as e:
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
if litellm.drop_params is True or drop_params is True:
invalid_params: List[str] = []
if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore
detail = e.body.get("detail") # type: ignore
if (
isinstance(detail, List)
and len(detail) > 0
and isinstance(detail[0], dict)
):
for error_dict in detail:
if (
error_dict.get("loc")
and isinstance(error_dict.get("loc"), list)
and len(error_dict.get("loc")) == 2
):
invalid_params.append(error_dict["loc"][1])
new_data = {}
for k, v in data.items():
if k not in invalid_params:
new_data[k] = v
data = new_data
else:
raise e
# e.message
except Exception as e:
raise e
def streaming(
self,
logging_obj,
timeout: Union[float, httpx.Timeout],
data: dict,
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
):
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=data["messages"],
api_key=api_key,
additional_args={
"headers": {"Authorization": f"Bearer {openai_client.api_key}"},
"api_base": openai_client._base_url._uri_reference,
"acompletion": False,
"complete_input_dict": data,
},
)
headers, response = self.make_sync_openai_chat_completion_request(
openai_client=openai_client,
data=data,
timeout=timeout,
)
logging_obj.model_call_details["response_headers"] = headers
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
_response_headers=headers,
)
return streamwrapper
async def async_streaming(
self,
timeout: Union[float, httpx.Timeout],
data: dict,
model: str,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
drop_params: Optional[bool] = None,
):
response = None
for _ in range(2):
try:
openai_aclient = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=data["messages"],
api_key=api_key,
additional_args={
"headers": headers,
"api_base": api_base,
"acompletion": True,
"complete_input_dict": data,
},
)
headers, response = await self.make_openai_chat_completion_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
)
logging_obj.model_call_details["response_headers"] = headers
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
_response_headers=headers,
)
return streamwrapper
except openai.UnprocessableEntityError as e:
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
if litellm.drop_params is True or drop_params is True:
invalid_params: List[str] = []
if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore
detail = e.body.get("detail") # type: ignore
if (
isinstance(detail, List)
and len(detail) > 0
and isinstance(detail[0], dict)
):
for error_dict in detail:
if (
error_dict.get("loc")
and isinstance(error_dict.get("loc"), list)
and len(error_dict.get("loc")) == 2
):
invalid_params.append(error_dict["loc"][1])
new_data = {}
for k, v in data.items():
if k not in invalid_params:
new_data[k] = v
data = new_data
else:
raise e
except (
Exception
) as e: # need to exception handle here. async exceptions don't get caught in sync functions.
if isinstance(e, OpenAIError):
raise e
if response is not None and hasattr(response, "text"):
raise OpenAIError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
)
else:
if type(e).__name__ == "ReadTimeout":
raise OpenAIError(
status_code=408, message=f"{type(e).__name__}"
)
elif hasattr(e, "status_code"):
raise OpenAIError(status_code=e.status_code, message=str(e))
else:
raise OpenAIError(status_code=500, message=f"{str(e)}")
# Embedding
async def make_openai_embedding_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call embeddings.create.with_raw_response when litellm.return_response_headers is True
- call embeddings.create by default
"""
try:
raw_response = await openai_aclient.embeddings.with_raw_response.create(
**data, timeout=timeout
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except Exception as e:
raise e
def make_sync_openai_embedding_request(
self,
openai_client: OpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call embeddings.create.with_raw_response when litellm.return_response_headers is True
- call embeddings.create by default
"""
try:
raw_response = openai_client.embeddings.with_raw_response.create(
**data, timeout=timeout
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except Exception as e:
raise e
async def aembedding(
self,
input: list,
data: dict,
model_response: litellm.utils.EmbeddingResponse,
timeout: float,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client: Optional[AsyncOpenAI] = None,
max_retries=None,
):
response = None
try:
openai_aclient = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
headers, response = await self.make_openai_embedding_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
)
logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
response_type="embedding",
_response_headers=headers,
) # type: ignore
except OpenAIError as e:
raise e
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
def embedding(
self,
model: str,
input: list,
timeout: float,
logging_obj,
model_response: litellm.utils.EmbeddingResponse,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
optional_params=None,
client=None,
aembedding=None,
):
super().embedding()
exception_mapping_worked = False
try:
model = model
data = {"model": model, "input": input, **optional_params}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data, "api_base": api_base},
)
if aembedding is True:
response = self.aembedding(
data=data,
input=input,
logging_obj=logging_obj,
model_response=model_response,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
)
return response
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
## embedding CALL
headers: Optional[Dict] = None
headers, sync_embedding_response = self.make_sync_openai_embedding_request(
openai_client=openai_client, data=data, timeout=timeout
) # type: ignore
## LOGGING
logging_obj.model_call_details["response_headers"] = headers
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=sync_embedding_response,
)
return convert_to_model_response_object(
response_object=sync_embedding_response.model_dump(),
model_response_object=model_response,
_response_headers=headers,
response_type="embedding",
) # type: ignore
except OpenAIError as e:
raise e
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
async def aimage_generation(
self,
prompt: str,
data: dict,
model_response: ModelResponse,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
):
response = None
try:
openai_aclient = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
)
raise e
def image_generation(
self,
model: Optional[str],
prompt: str,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
logging_obj=None,
optional_params=None,
client=None,
aimg_generation=None,
):
exception_mapping_worked = False
try:
model = model
data = {"model": model, "prompt": prompt, **optional_params}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
if aimg_generation == True:
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=openai_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {openai_client.api_key}"},
"api_base": openai_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
## COMPLETION CALL
response = openai_client.images.generate(**data, timeout=timeout) # type: ignore
response = response.model_dump() # type: ignore
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
# return response
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except OpenAIError as e:
exception_mapping_worked = True
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
except Exception as e:
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
if hasattr(e, "status_code"):
raise OpenAIError(status_code=e.status_code, message=str(e))
else:
raise OpenAIError(status_code=500, message=str(e))
# Audio Transcriptions
async def make_openai_audio_transcriptions_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
- call openai_aclient.audio.transcriptions.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = (
await openai_aclient.audio.transcriptions.with_raw_response.create(
**data, timeout=timeout
)
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = await openai_aclient.audio.transcriptions.create(**data, timeout=timeout) # type: ignore
return None, response
except Exception as e:
raise e
def make_sync_openai_audio_transcriptions_request(
self,
openai_client: OpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
- call openai_aclient.audio.transcriptions.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = (
openai_client.audio.transcriptions.with_raw_response.create(
**data, timeout=timeout
)
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = openai_client.audio.transcriptions.create(**data, timeout=timeout) # type: ignore
return None, response
except Exception as e:
raise e
def audio_transcriptions(
self,
model: str,
audio_file: BinaryIO,
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
api_key: Optional[str],
api_base: Optional[str],
client=None,
logging_obj=None,
atranscription: bool = False,
):
data = {"model": model, "file": audio_file, **optional_params}
if atranscription is True:
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
model_response=model_response,
timeout=timeout,
api_key=api_key,
api_base=api_base,
client=client,
max_retries=max_retries,
logging_obj=logging_obj,
)
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
)
_, response = self.make_sync_openai_audio_transcriptions_request(
openai_client=openai_client,
data=data,
timeout=timeout,
)
if isinstance(response, BaseModel):
stringified_response = response.model_dump()
else:
stringified_response = TranscriptionResponse(text=response).model_dump()
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
self,
audio_file: BinaryIO,
data: dict,
model_response: TranscriptionResponse,
timeout: float,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
):
try:
openai_aclient = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
headers, response = await self.make_openai_audio_transcriptions_request(
openai_aclient=openai_aclient,
data=data,
timeout=timeout,
)
logging_obj.model_call_details["response_headers"] = headers
if isinstance(response, BaseModel):
stringified_response = response.model_dump()
else:
stringified_response = TranscriptionResponse(text=response).model_dump()
## LOGGING
logging_obj.post_call(
input=audio_file.name,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
)
raise e
def audio_speech(
self,
model: str,
input: str,
voice: str,
optional_params: dict,
api_key: Optional[str],
api_base: Optional[str],
organization: Optional[str],
project: Optional[str],
max_retries: int,
timeout: Union[float, httpx.Timeout],
aspeech: Optional[bool] = None,
client=None,
) -> HttpxBinaryResponseContent:
if aspeech is not None and aspeech is True:
return self.async_audio_speech(
model=model,
input=input,
voice=voice,
optional_params=optional_params,
api_key=api_key,
api_base=api_base,
organization=organization,
project=project,
max_retries=max_retries,
timeout=timeout,
client=client,
) # type: ignore
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = openai_client.audio.speech.create(
model=model,
voice=voice, # type: ignore
input=input,
**optional_params,
)
return response
async def async_audio_speech(
self,
model: str,
input: str,
voice: str,
optional_params: dict,
api_key: Optional[str],
api_base: Optional[str],
organization: Optional[str],
project: Optional[str],
max_retries: int,
timeout: Union[float, httpx.Timeout],
client=None,
) -> HttpxBinaryResponseContent:
openai_client = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = await openai_client.audio.speech.create(
model=model,
voice=voice, # type: ignore
input=input,
**optional_params,
)
return response
async def ahealth_check(
self,
model: Optional[str],
api_key: str,
timeout: float,
mode: str,
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
organization: Optional[str] = None,
api_base: Optional[str] = None,
):
client = AsyncOpenAI(
api_key=api_key,
timeout=timeout,
organization=organization,
base_url=api_base,
)
if model is None and mode != "image_generation":
raise Exception("model is not set")
completion = None
if mode == "completion":
completion = await client.completions.with_raw_response.create(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "chat":
if messages is None:
raise Exception("messages is not set")
completion = await client.chat.completions.with_raw_response.create(
model=model, # type: ignore
messages=messages, # type: ignore
)
elif mode == "embedding":
if input is None:
raise Exception("input is not set")
completion = await client.embeddings.with_raw_response.create(
model=model, # type: ignore
input=input, # type: ignore
)
elif mode == "image_generation":
if prompt is None:
raise Exception("prompt is not set")
completion = await client.images.with_raw_response.generate(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(pwd, "../tests/gettysburg.wav")
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
else:
raise ValueError("mode not set, passed in mode: " + mode)
response = {}
if completion is None or not hasattr(completion, "headers"):
raise Exception("invalid completion response")
if (
completion.headers.get("x-ratelimit-remaining-requests", None) is not None
): # not provided for dall-e requests
response["x-ratelimit-remaining-requests"] = completion.headers[
"x-ratelimit-remaining-requests"
]
if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
response["x-ratelimit-remaining-tokens"] = completion.headers[
"x-ratelimit-remaining-tokens"
]
return response
class OpenAITextCompletion(BaseLLM):
_client_session: httpx.Client
def __init__(self) -> None:
super().__init__()
self._client_session = self.create_client_session()
def validate_environment(self, api_key):
headers = {
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completion(
self,
model_response: ModelResponse,
api_key: str,
model: str,
messages: list,
timeout: float,
logging_obj: LiteLLMLoggingObj,
print_verbose: Optional[Callable] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
optional_params=None,
litellm_params=None,
logger_fn=None,
client=None,
organization: Optional[str] = None,
headers: Optional[dict] = None,
):
super().completion()
exception_mapping_worked = False
try:
if headers is None:
headers = self.validate_environment(api_key=api_key)
if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages")
if (
len(messages) > 0
and "content" in messages[0]
and type(messages[0]["content"]) == list
):
prompt = messages[0]["content"]
else:
prompt = [message["content"] for message in messages] # type: ignore
# don't send max retries to the api, if set
data = {"model": model, "prompt": prompt, **optional_params}
max_retries = data.pop("max_retries", 2)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"headers": headers,
"api_base": api_base,
"complete_input_dict": data,
},
)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries,
client=client,
organization=organization,
)
else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore
elif optional_params.get("stream", False):
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries, # type: ignore
client=client,
organization=organization,
)
else:
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
raw_response = openai_client.completions.with_raw_response.create(**data) # type: ignore
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response_json,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return TextCompletionResponse(**response_json)
except Exception as e:
raise e
async def acompletion(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
prompt: str,
api_key: str,
model: str,
timeout: float,
max_retries=None,
organization: Optional[str] = None,
client=None,
):
try:
if client is None:
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
raw_response = await openai_aclient.completions.with_raw_response.create(
**data
)
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
response_obj = TextCompletionResponse(**response_json)
response_obj._hidden_params.original_response = json.dumps(response_json)
return response_obj
except Exception as e:
raise e
def streaming(
self,
logging_obj,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
api_base: Optional[str] = None,
max_retries=None,
client=None,
organization=None,
):
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
try:
response = openai_client.completions.with_raw_response.create(**data)
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
for chunk in streamwrapper:
yield chunk
async def async_streaming(
self,
logging_obj,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
api_base: Optional[str] = None,
client=None,
max_retries=None,
organization=None,
):
if client is None:
openai_client = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
response = await openai_client.completions.with_raw_response.create(**data)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
class OpenAIFilesAPI(BaseLLM):
"""
OpenAI methods to support for batches
- create_file()
- retrieve_file()
- list_files()
- delete_file()
- file_content()
- update_file()
"""
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[OpenAI, AsyncOpenAI]]:
received_args = locals()
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
if _is_async is True:
openai_client = AsyncOpenAI(**data)
else:
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_file(
self,
create_file_data: CreateFileRequest,
openai_client: AsyncOpenAI,
) -> FileObject:
response = await openai_client.files.create(**create_file_data)
return response
def create_file(
self,
_is_async: bool,
create_file_data: CreateFileRequest,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acreate_file( # type: ignore
create_file_data=create_file_data, openai_client=openai_client
)
response = openai_client.files.create(**create_file_data)
return response
async def afile_content(
self,
file_content_request: FileContentRequest,
openai_client: AsyncOpenAI,
) -> HttpxBinaryResponseContent:
response = await openai_client.files.content(**file_content_request)
return response
def file_content(
self,
_is_async: bool,
file_content_request: FileContentRequest,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.afile_content( # type: ignore
file_content_request=file_content_request,
openai_client=openai_client,
)
response = openai_client.files.content(**file_content_request)
return response
async def aretrieve_file(
self,
file_id: str,
openai_client: AsyncOpenAI,
) -> FileObject:
response = await openai_client.files.retrieve(file_id=file_id)
return response
def retrieve_file(
self,
_is_async: bool,
file_id: str,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_file( # type: ignore
file_id=file_id,
openai_client=openai_client,
)
response = openai_client.files.retrieve(file_id=file_id)
return response
async def adelete_file(
self,
file_id: str,
openai_client: AsyncOpenAI,
) -> FileDeleted:
response = await openai_client.files.delete(file_id=file_id)
return response
def delete_file(
self,
_is_async: bool,
file_id: str,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.adelete_file( # type: ignore
file_id=file_id,
openai_client=openai_client,
)
response = openai_client.files.delete(file_id=file_id)
return response
async def alist_files(
self,
openai_client: AsyncOpenAI,
purpose: Optional[str] = None,
):
if isinstance(purpose, str):
response = await openai_client.files.list(purpose=purpose)
else:
response = await openai_client.files.list()
return response
def list_files(
self,
_is_async: bool,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
purpose: Optional[str] = None,
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.alist_files( # type: ignore
purpose=purpose,
openai_client=openai_client,
)
if isinstance(purpose, str):
response = openai_client.files.list(purpose=purpose)
else:
response = openai_client.files.list()
return response
class OpenAIBatchesAPI(BaseLLM):
"""
OpenAI methods to support for batches
- create_batch()
- retrieve_batch()
- cancel_batch()
- list_batch()
"""
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[OpenAI, AsyncOpenAI]]:
received_args = locals()
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
if _is_async is True:
openai_client = AsyncOpenAI(**data)
else:
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_batch(
self,
create_batch_data: CreateBatchRequest,
openai_client: AsyncOpenAI,
) -> Batch:
response = await openai_client.batches.create(**create_batch_data)
return response
def create_batch(
self,
_is_async: bool,
create_batch_data: CreateBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acreate_batch( # type: ignore
create_batch_data=create_batch_data, openai_client=openai_client
)
response = openai_client.batches.create(**create_batch_data)
return response
async def aretrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
openai_client: AsyncOpenAI,
) -> Batch:
verbose_logger.debug("retrieving batch, args= %s", retrieve_batch_data)
response = await openai_client.batches.retrieve(**retrieve_batch_data)
return response
def retrieve_batch(
self,
_is_async: bool,
retrieve_batch_data: RetrieveBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
)
response = openai_client.batches.retrieve(**retrieve_batch_data)
return response
def cancel_batch(
self,
_is_async: bool,
cancel_batch_data: CancelBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
response = openai_client.batches.cancel(**cancel_batch_data)
return response
async def alist_batches(
self,
openai_client: AsyncOpenAI,
after: Optional[str] = None,
limit: Optional[int] = None,
):
verbose_logger.debug("listing batches, after= %s, limit= %s", after, limit)
response = await openai_client.batches.list(after=after, limit=limit) # type: ignore
return response
def list_batches(
self,
_is_async: bool,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
after: Optional[str] = None,
limit: Optional[int] = None,
client: Optional[OpenAI] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.alist_batches( # type: ignore
openai_client=openai_client, after=after, limit=limit
)
response = openai_client.batches.list(after=after, limit=limit) # type: ignore
return response
class OpenAIAssistantsAPI(BaseLLM):
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
def async_get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> AsyncOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = AsyncOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
### ASSISTANTS ###
async def async_get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> AsyncCursorPage[Assistant]:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.list()
return response
# fmt: off
@overload
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_assistants: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
...
@overload
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_assistants: Optional[Literal[False]],
) -> SyncCursorPage[Assistant]:
...
# fmt: on
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_assistants=None,
):
if aget_assistants is not None and aget_assistants == True:
return self.async_get_assistants(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.assistants.list()
return response
# Create Assistant
async def async_create_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
create_assistant_data: dict,
) -> Assistant:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.create(**create_assistant_data)
return response
def create_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
create_assistant_data: dict,
client=None,
async_create_assistants=None,
):
if async_create_assistants is not None and async_create_assistants == True:
return self.async_create_assistants(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
create_assistant_data=create_assistant_data,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.assistants.create(**create_assistant_data)
return response
# Delete Assistant
async def async_delete_assistant(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
assistant_id: str,
) -> AssistantDeleted:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.delete(assistant_id=assistant_id)
return response
def delete_assistant(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
assistant_id: str,
client=None,
async_delete_assistants=None,
):
if async_delete_assistants is not None and async_delete_assistants == True:
return self.async_delete_assistant(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
assistant_id=assistant_id,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.assistants.delete(assistant_id=assistant_id)
return response
### MESSAGES ###
async def a_add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> OpenAIMessage:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data # type: ignore
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
# fmt: off
@overload
def add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
a_add_message: Literal[True],
) -> Coroutine[None, None, OpenAIMessage]:
...
@overload
def add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
a_add_message: Optional[Literal[False]],
) -> OpenAIMessage:
...
# fmt: on
def add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
a_add_message: Optional[bool] = None,
):
if a_add_message is not None and a_add_message == True:
return self.a_add_message(
thread_id=thread_id,
message_data=message_data,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data # type: ignore
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
async def async_get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> AsyncCursorPage[OpenAIMessage]:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
# fmt: off
@overload
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_messages: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
...
@overload
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_messages: Optional[Literal[False]],
) -> SyncCursorPage[OpenAIMessage]:
...
# fmt: on
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_messages=None,
):
if aget_messages is not None and aget_messages == True:
return self.async_get_messages(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
### THREADS ###
async def async_create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = await openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
# fmt: off
@overload
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AsyncOpenAI],
acreate_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[OpenAI],
acreate_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client=None,
acreate_thread=None,
):
"""
Here's an example:
```
from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
# create thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
openai_api.create_thread(messages=[message])
```
"""
if acreate_thread is not None and acreate_thread == True:
return self.async_create_thread(
metadata=metadata,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
messages=messages,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
async def async_get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> Thread:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
# fmt: off
@overload
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_thread=None,
):
if aget_thread is not None and aget_thread == True:
return self.async_get_thread(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
def delete_thread(self):
pass
### RUNS ###
async def arun_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> Run:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response
def async_run_thread_stream(
self,
client: AsyncOpenAI,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
data = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
"instructions": instructions,
"metadata": metadata,
"model": model,
"tools": tools,
}
if event_handler is not None:
data["event_handler"] = event_handler
return client.beta.threads.runs.stream(**data) # type: ignore
def run_thread_stream(
self,
client: OpenAI,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
) -> AssistantStreamManager[AssistantEventHandler]:
data = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
"instructions": instructions,
"metadata": metadata,
"model": model,
"tools": tools,
}
if event_handler is not None:
data["event_handler"] = event_handler
return client.beta.threads.runs.stream(**data) # type: ignore
# fmt: off
@overload
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client,
arun_thread: Literal[True],
event_handler: Optional[AssistantEventHandler],
) -> Coroutine[None, None, Run]:
...
@overload
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client,
arun_thread: Optional[Literal[False]],
event_handler: Optional[AssistantEventHandler],
) -> Run:
...
# fmt: on
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
if arun_thread is not None and arun_thread == True:
if stream is not None and stream == True:
_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
return self.async_run_thread_stream(
client=_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
if stream is not None and stream == True:
return self.run_thread_stream(
client=openai_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response