litellm-mirror/litellm/llms/OpenAI/openai.py
2024-09-05 09:12:27 -07:00

3414 lines
119 KiB
Python

import hashlib
import json
import os
import time
import traceback
import types
from typing import Any, Callable, Coroutine, Iterable, Literal, Optional, Union
import httpx
import openai
from openai import AsyncOpenAI, OpenAI
from openai.types.beta.assistant_deleted import AssistantDeleted
from openai.types.file_deleted import FileDeleted
from pydantic import BaseModel
from typing_extensions import overload, override
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import ProviderField
from litellm.utils import (
Choices,
CustomStreamWrapper,
Message,
ModelResponse,
TextCompletionResponse,
Usage,
convert_to_model_response_object,
)
from ...types.llms.openai import *
from ..base import BaseLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory
class OpenAIError(Exception):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
headers: Optional[httpx.Headers] = None,
):
self.status_code = status_code
self.message = message
self.headers = headers
if request:
self.request = request
else:
self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
if response:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class MistralConfig:
"""
Reference: https://docs.mistral.ai/api/
The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
- `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
"""
temperature: Optional[int] = None
top_p: Optional[int] = None
max_tokens: Optional[int] = None
tools: Optional[list] = None
tool_choice: Optional[Literal["auto", "any", "none"]] = None
random_seed: Optional[int] = None
safe_prompt: Optional[bool] = None
response_format: Optional[dict] = None
stop: Optional[Union[str, list]] = None
def __init__(
self,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
max_tokens: Optional[int] = None,
tools: Optional[list] = None,
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
stop: Optional[Union[str, list]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
"seed",
"stop",
"response_format",
]
def _map_tool_choice(self, tool_choice: str) -> str:
if tool_choice == "auto" or tool_choice == "none":
return tool_choice
elif tool_choice == "required":
return "any"
else: # openai 'tool_choice' object param not supported by Mistral API
return "any"
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
if param == "response_format":
optional_params["response_format"] = value
return optional_params
class MistralEmbeddingConfig:
"""
Reference: https://docs.mistral.ai/api/#operation/createEmbedding
"""
def __init__(
self,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"encoding_format",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "encoding_format":
optional_params["encoding_format"] = value
return optional_params
class AzureAIStudioConfig:
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
ProviderField(
field_name="api_key",
field_type="string",
field_description="Your Azure AI Studio API Key.",
field_value="zEJ...",
),
ProviderField(
field_name="api_base",
field_type="string",
field_description="Your Azure AI Studio API Base.",
field_value="https://Mistral-serverless.",
),
]
class DeepInfraConfig:
"""
Reference: https://deepinfra.com/docs/advanced/openai_api
The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
tools: Optional[list] = None
tool_choice: Optional[Union[str, dict]] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
tools: Optional[list] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"frequency_penalty",
"function_call",
"functions",
"logit_bias",
"max_tokens",
"n",
"presence_penalty",
"stop",
"temperature",
"top_p",
"response_format",
"tools",
"tool_choice",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params()
for param, value in non_default_params.items():
if (
param == "temperature"
and value == 0
and model == "mistralai/Mistral-7B-Instruct-v0.1"
): # this model does no support temperature == 0
value = 0.0001 # close to 0
if param == "tool_choice":
if (
value != "auto" and value != "none"
): # https://deepinfra.com/docs/advanced/function_calling
## UNSUPPORTED TOOL CHOICE VALUE
if litellm.drop_params is True or drop_params is True:
value = None
else:
raise litellm.utils.UnsupportedParamsError(
message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
value
),
status_code=400,
)
if param in supported_openai_params:
if value is not None:
optional_params[param] = value
return optional_params
class GroqConfig:
"""
Reference: https://deepinfra.com/docs/advanced/openai_api
The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
tools: Optional[list] = None
tool_choice: Optional[Union[str, dict]] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
tools: Optional[list] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params_stt(self):
return [
"prompt",
"response_format",
"temperature",
"language",
]
def get_supported_openai_response_formats_stt(self) -> List[str]:
return ["json", "verbose_json", "text"]
def map_openai_params_stt(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
response_formats = self.get_supported_openai_response_formats_stt()
for param, value in non_default_params.items():
if param == "response_format":
if value in response_formats:
optional_params[param] = value
else:
if litellm.drop_params is True or drop_params is True:
pass
else:
raise litellm.utils.UnsupportedParamsError(
message="Groq doesn't support response_format={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
value
),
status_code=400,
)
else:
optional_params[param] = value
return optional_params
class OpenAIConfig:
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
The class `OpenAIConfig` provides configuration for the OpenAI's Chat API interface. Below are the parameters:
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
- `function_call` (string or object): This optional parameter controls how the model calls functions.
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> list:
base_params = [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"function_call",
"functions",
"max_retries",
"extra_headers",
"parallel_tool_calls",
] # works across all models
model_specific_params = []
if (
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
): # gpt-4 does not support 'response_format'
model_specific_params.append("response_format")
if (
model in litellm.open_ai_chat_completion_models
) or model in litellm.open_ai_text_completion_models:
model_specific_params.append(
"user"
) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
return base_params + model_specific_params
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
) -> dict:
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
class OpenAITextCompletionConfig:
"""
Reference: https://platform.openai.com/docs/api-reference/completions/create
The class `OpenAITextCompletionConfig` provides configuration for the OpenAI's text completion API interface. Below are the parameters:
- `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.
- `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.
- `frequency_penalty` (number or null): Defaults to 0. It is a numbers from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens.
- `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.
- `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.
- `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.
- `temperature` (number or null): This optional parameter defines the sampling temperature to use.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
best_of: Optional[int] = None
echo: Optional[bool] = None
frequency_penalty: Optional[int] = None
logit_bias: Optional[dict] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
suffix: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
def __init__(
self,
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[int] = None,
logit_bias: Optional[dict] = None,
logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
suffix: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def convert_to_chat_model_response_object(
self,
response_object: Optional[TextCompletionResponse] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["text"],
role="assistant",
)
choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
)
choice_list.append(choice)
model_response_object.choices = choice_list
if "usage" in response_object:
setattr(model_response_object, "usage", response_object["usage"])
if "id" in response_object:
model_response_object.id = response_object["id"]
if "model" in response_object:
model_response_object.model = response_object["model"]
model_response_object._hidden_params["original_response"] = (
response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
)
return model_response_object
except Exception as e:
raise e
class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _get_openai_client(
self,
is_async: bool,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
max_retries: Optional[int] = 2,
organization: Optional[str] = None,
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
args = locals()
if client is None:
if not isinstance(max_retries, int):
raise OpenAIError(
status_code=422,
message="max retries must be an int. Passed in value: {}".format(
max_retries
),
)
# Creating a new OpenAI Client
# check in memory cache before creating a new one
# Convert the API key to bytes
hashed_api_key = None
if api_key is not None:
hash_object = hashlib.sha256(api_key.encode())
# Hexadecimal representation of the hash
hashed_api_key = hash_object.hexdigest()
_cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"
if _cache_key in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key]
if is_async:
_new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
_new_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
## SAVE CACHE KEY
litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
return _new_client
else:
return client
async def make_openai_chat_completion_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
) -> Tuple[dict, BaseModel]:
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default
"""
try:
raw_response = (
await openai_aclient.chat.completions.with_raw_response.create(
**data, timeout=timeout
)
)
if hasattr(raw_response, "headers"):
headers = dict(raw_response.headers)
else:
headers = {}
response = raw_response.parse()
return headers, response
except Exception as e:
raise e
def make_sync_openai_chat_completion_request(
self,
openai_client: OpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
) -> Tuple[dict, BaseModel]:
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default
"""
try:
raw_response = openai_client.chat.completions.with_raw_response.create(
**data, timeout=timeout
)
if hasattr(raw_response, "headers"):
headers = dict(raw_response.headers)
else:
headers = {}
response = raw_response.parse()
return headers, response
except Exception as e:
raise e
def completion(
self,
model_response: ModelResponse,
timeout: Union[float, httpx.Timeout],
optional_params: dict,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
logging_obj=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
super().completion()
exception_mapping_worked = False
try:
if headers:
optional_params["extra_headers"] = headers
if model is None or messages is None:
raise OpenAIError(status_code=422, message="Missing model or messages")
if not isinstance(timeout, float) and not isinstance(
timeout, httpx.Timeout
):
raise OpenAIError(
status_code=422,
message="Timeout needs to be a float or httpx.Timeout",
)
if custom_llm_provider is not None and custom_llm_provider != "openai":
model_response.model = f"{custom_llm_provider}/{model}"
# process all OpenAI compatible provider logic here
if custom_llm_provider == "mistral":
# check if message content passed in as list, and not string
messages = prompt_factory(
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
if custom_llm_provider == "perplexity" and messages is not None:
# check if messages.name is passed + supported, if not supported remove
messages = prompt_factory(
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
for _ in range(
2
): # if call fails due to alternating messages, retry with reformatted message
data = {"model": model, "messages": messages, **optional_params}
try:
max_retries = data.pop("max_retries", 2)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(
logging_obj=logging_obj,
headers=headers,
data=data,
model=model,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
drop_params=drop_params,
)
else:
return self.acompletion(
data=data,
headers=headers,
logging_obj=logging_obj,
model_response=model_response,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
drop_params=drop_params,
)
elif optional_params.get("stream", False):
return self.streaming(
logging_obj=logging_obj,
headers=headers,
data=data,
model=model,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
else:
if not isinstance(max_retries, int):
raise OpenAIError(
status_code=422, message="max retries must be an int"
)
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=openai_client.api_key,
additional_args={
"headers": headers,
"api_base": openai_client._base_url._uri_reference,
"acompletion": acompletion,
"complete_input_dict": data,
},
)
headers, response = (
self.make_sync_openai_chat_completion_request(
openai_client=openai_client,
data=data,
timeout=timeout,
)
)
logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
_response_headers=headers,
)
except openai.UnprocessableEntityError as e:
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
if litellm.drop_params is True or drop_params is True:
invalid_params: List[str] = []
if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore
detail = e.body.get("detail") # type: ignore
if (
isinstance(detail, List)
and len(detail) > 0
and isinstance(detail[0], dict)
):
for error_dict in detail:
if (
error_dict.get("loc")
and isinstance(error_dict.get("loc"), list)
and len(error_dict.get("loc")) == 2
):
invalid_params.append(error_dict["loc"][1])
new_data = {}
for k, v in optional_params.items():
if k not in invalid_params:
new_data[k] = v
optional_params = new_data
else:
raise e
# e.message
except Exception as e:
if print_verbose is not None:
print_verbose(f"openai.py: Received openai error - {str(e)}")
if (
"Conversation roles must alternate user/assistant" in str(e)
or "user and assistant roles should be alternating" in str(e)
) and messages is not None:
if print_verbose is not None:
print_verbose("openai.py: REFORMATS THE MESSAGE!")
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
new_messages = []
for i in range(len(messages) - 1): # type: ignore
new_messages.append(messages[i])
if messages[i]["role"] == messages[i + 1]["role"]:
if messages[i]["role"] == "user":
new_messages.append(
{"role": "assistant", "content": ""}
)
else:
new_messages.append({"role": "user", "content": ""})
new_messages.append(messages[-1])
messages = new_messages
elif (
"Last message must have role `user`" in str(e)
) and messages is not None:
new_messages = messages
new_messages.append({"role": "user", "content": ""})
messages = new_messages
elif (
"unknown field: parameter index is not a valid field" in str(e)
) and "tools" in data:
litellm.remove_index_from_tool_calls(
tool_calls=data["tools"], messages=messages
)
else:
raise e
except OpenAIError as e:
raise e
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
async def acompletion(
self,
data: dict,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
timeout: Union[float, httpx.Timeout],
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
drop_params: Optional[bool] = None,
):
response = None
for _ in range(
2
): # if call fails due to alternating messages, retry with reformatted message
try:
openai_aclient = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=data["messages"],
api_key=openai_aclient.api_key,
additional_args={
"headers": {
"Authorization": f"Bearer {openai_aclient.api_key}"
},
"api_base": openai_aclient._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
headers, response = await self.make_openai_chat_completion_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
)
stringified_response = response.model_dump()
logging_obj.post_call(
input=data["messages"],
api_key=api_key,
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
logging_obj.model_call_details["response_headers"] = headers
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
hidden_params={"headers": headers},
_response_headers=headers,
)
except openai.UnprocessableEntityError as e:
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
if litellm.drop_params is True or drop_params is True:
invalid_params: List[str] = []
if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore
detail = e.body.get("detail") # type: ignore
if (
isinstance(detail, List)
and len(detail) > 0
and isinstance(detail[0], dict)
):
for error_dict in detail:
if (
error_dict.get("loc")
and isinstance(error_dict.get("loc"), list)
and len(error_dict.get("loc")) == 2
):
invalid_params.append(error_dict["loc"][1])
new_data = {}
for k, v in data.items():
if k not in invalid_params:
new_data[k] = v
data = new_data
else:
raise e
# e.message
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
def streaming(
self,
logging_obj,
timeout: Union[float, httpx.Timeout],
data: dict,
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
):
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=data["messages"],
api_key=api_key,
additional_args={
"headers": {"Authorization": f"Bearer {openai_client.api_key}"},
"api_base": openai_client._base_url._uri_reference,
"acompletion": False,
"complete_input_dict": data,
},
)
headers, response = self.make_sync_openai_chat_completion_request(
openai_client=openai_client,
data=data,
timeout=timeout,
)
logging_obj.model_call_details["response_headers"] = headers
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
_response_headers=headers,
)
return streamwrapper
async def async_streaming(
self,
timeout: Union[float, httpx.Timeout],
data: dict,
model: str,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
drop_params: Optional[bool] = None,
):
response = None
for _ in range(2):
try:
openai_aclient = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=data["messages"],
api_key=api_key,
additional_args={
"headers": headers,
"api_base": api_base,
"acompletion": True,
"complete_input_dict": data,
},
)
headers, response = await self.make_openai_chat_completion_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
)
logging_obj.model_call_details["response_headers"] = headers
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
_response_headers=headers,
)
return streamwrapper
except openai.UnprocessableEntityError as e:
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
if litellm.drop_params is True or drop_params is True:
invalid_params: List[str] = []
if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"): # type: ignore
detail = e.body.get("detail") # type: ignore
if (
isinstance(detail, List)
and len(detail) > 0
and isinstance(detail[0], dict)
):
for error_dict in detail:
if (
error_dict.get("loc")
and isinstance(error_dict.get("loc"), list)
and len(error_dict.get("loc")) == 2
):
invalid_params.append(error_dict["loc"][1])
new_data = {}
for k, v in data.items():
if k not in invalid_params:
new_data[k] = v
data = new_data
else:
raise e
except (
Exception
) as e: # need to exception handle here. async exceptions don't get caught in sync functions.
if isinstance(e, OpenAIError):
raise e
error_headers = getattr(e, "headers", None)
if response is not None and hasattr(response, "text"):
raise OpenAIError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
headers=error_headers,
)
else:
if type(e).__name__ == "ReadTimeout":
raise OpenAIError(
status_code=408,
message=f"{type(e).__name__}",
headers=error_headers,
)
elif hasattr(e, "status_code"):
raise OpenAIError(
status_code=e.status_code,
message=str(e),
headers=error_headers,
)
else:
raise OpenAIError(
status_code=500, message=f"{str(e)}", headers=error_headers
)
# Embedding
async def make_openai_embedding_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call embeddings.create.with_raw_response when litellm.return_response_headers is True
- call embeddings.create by default
"""
try:
raw_response = await openai_aclient.embeddings.with_raw_response.create(
**data, timeout=timeout
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except Exception as e:
raise e
def make_sync_openai_embedding_request(
self,
openai_client: OpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call embeddings.create.with_raw_response when litellm.return_response_headers is True
- call embeddings.create by default
"""
try:
raw_response = openai_client.embeddings.with_raw_response.create(
**data, timeout=timeout
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except Exception as e:
raise e
async def aembedding(
self,
input: list,
data: dict,
model_response: litellm.utils.EmbeddingResponse,
timeout: float,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client: Optional[AsyncOpenAI] = None,
max_retries=None,
):
response = None
try:
openai_aclient = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
headers, response = await self.make_openai_embedding_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
)
logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
response_type="embedding",
_response_headers=headers,
) # type: ignore
except OpenAIError as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
def embedding(
self,
model: str,
input: list,
timeout: float,
logging_obj,
model_response: litellm.utils.EmbeddingResponse,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
optional_params=None,
client=None,
aembedding=None,
):
super().embedding()
exception_mapping_worked = False
try:
model = model
data = {"model": model, "input": input, **optional_params}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data, "api_base": api_base},
)
if aembedding is True:
response = self.aembedding(
data=data,
input=input,
logging_obj=logging_obj,
model_response=model_response,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
)
return response
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
## embedding CALL
headers: Optional[Dict] = None
headers, sync_embedding_response = self.make_sync_openai_embedding_request(
openai_client=openai_client, data=data, timeout=timeout
) # type: ignore
## LOGGING
logging_obj.model_call_details["response_headers"] = headers
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=sync_embedding_response,
)
return convert_to_model_response_object(
response_object=sync_embedding_response.model_dump(),
model_response_object=model_response,
_response_headers=headers,
response_type="embedding",
) # type: ignore
except OpenAIError as e:
raise e
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
async def aimage_generation(
self,
prompt: str,
data: dict,
model_response: ModelResponse,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
):
response = None
try:
openai_aclient = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
)
raise e
def image_generation(
self,
model: Optional[str],
prompt: str,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
logging_obj=None,
optional_params=None,
client=None,
aimg_generation=None,
):
exception_mapping_worked = False
try:
model = model
data = {"model": model, "prompt": prompt, **optional_params}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
if aimg_generation == True:
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=openai_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {openai_client.api_key}"},
"api_base": openai_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
## COMPLETION CALL
response = openai_client.images.generate(**data, timeout=timeout) # type: ignore
response = response.model_dump() # type: ignore
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
# return response
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except OpenAIError as e:
exception_mapping_worked = True
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
except Exception as e:
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
if hasattr(e, "status_code"):
raise OpenAIError(status_code=e.status_code, message=str(e))
else:
raise OpenAIError(status_code=500, message=str(e))
def audio_speech(
self,
model: str,
input: str,
voice: str,
optional_params: dict,
api_key: Optional[str],
api_base: Optional[str],
organization: Optional[str],
project: Optional[str],
max_retries: int,
timeout: Union[float, httpx.Timeout],
aspeech: Optional[bool] = None,
client=None,
) -> HttpxBinaryResponseContent:
if aspeech is not None and aspeech is True:
return self.async_audio_speech(
model=model,
input=input,
voice=voice,
optional_params=optional_params,
api_key=api_key,
api_base=api_base,
organization=organization,
project=project,
max_retries=max_retries,
timeout=timeout,
client=client,
) # type: ignore
openai_client = self._get_openai_client(
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = openai_client.audio.speech.create(
model=model,
voice=voice, # type: ignore
input=input,
**optional_params,
)
return response
async def async_audio_speech(
self,
model: str,
input: str,
voice: str,
optional_params: dict,
api_key: Optional[str],
api_base: Optional[str],
organization: Optional[str],
project: Optional[str],
max_retries: int,
timeout: Union[float, httpx.Timeout],
client=None,
) -> HttpxBinaryResponseContent:
openai_client = self._get_openai_client(
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = await openai_client.audio.speech.create(
model=model,
voice=voice, # type: ignore
input=input,
**optional_params,
)
return response
async def ahealth_check(
self,
model: Optional[str],
api_key: str,
timeout: float,
mode: str,
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
organization: Optional[str] = None,
api_base: Optional[str] = None,
):
client = AsyncOpenAI(
api_key=api_key,
timeout=timeout,
organization=organization,
base_url=api_base,
)
if model is None and mode != "image_generation":
raise Exception("model is not set")
completion = None
if mode == "completion":
completion = await client.completions.with_raw_response.create(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "chat":
if messages is None:
raise Exception("messages is not set")
completion = await client.chat.completions.with_raw_response.create(
model=model, # type: ignore
messages=messages, # type: ignore
)
elif mode == "embedding":
if input is None:
raise Exception("input is not set")
completion = await client.embeddings.with_raw_response.create(
model=model, # type: ignore
input=input, # type: ignore
)
elif mode == "image_generation":
if prompt is None:
raise Exception("prompt is not set")
completion = await client.images.with_raw_response.generate(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(pwd, "../tests/gettysburg.wav")
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
else:
raise ValueError("mode not set, passed in mode: " + mode)
response = {}
if completion is None or not hasattr(completion, "headers"):
raise Exception("invalid completion response")
if (
completion.headers.get("x-ratelimit-remaining-requests", None) is not None
): # not provided for dall-e requests
response["x-ratelimit-remaining-requests"] = completion.headers[
"x-ratelimit-remaining-requests"
]
if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
response["x-ratelimit-remaining-tokens"] = completion.headers[
"x-ratelimit-remaining-tokens"
]
return response
class OpenAITextCompletion(BaseLLM):
_client_session: httpx.Client
def __init__(self) -> None:
super().__init__()
self._client_session = self.create_client_session()
def validate_environment(self, api_key):
headers = {
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completion(
self,
model_response: ModelResponse,
api_key: str,
model: str,
messages: list,
timeout: float,
logging_obj: LiteLLMLoggingObj,
print_verbose: Optional[Callable] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
optional_params=None,
litellm_params=None,
logger_fn=None,
client=None,
organization: Optional[str] = None,
headers: Optional[dict] = None,
):
super().completion()
exception_mapping_worked = False
try:
if headers is None:
headers = self.validate_environment(api_key=api_key)
if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages")
if (
len(messages) > 0
and "content" in messages[0]
and type(messages[0]["content"]) == list
):
prompt = messages[0]["content"]
else:
prompt = [message["content"] for message in messages] # type: ignore
# don't send max retries to the api, if set
data = {"model": model, "prompt": prompt, **optional_params}
max_retries = data.pop("max_retries", 2)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"headers": headers,
"api_base": api_base,
"complete_input_dict": data,
},
)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries,
client=client,
organization=organization,
)
else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore
elif optional_params.get("stream", False):
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries, # type: ignore
client=client,
organization=organization,
)
else:
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
raw_response = openai_client.completions.with_raw_response.create(**data) # type: ignore
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response_json,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return TextCompletionResponse(**response_json)
except Exception as e:
raise e
async def acompletion(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
prompt: str,
api_key: str,
model: str,
timeout: float,
max_retries=None,
organization: Optional[str] = None,
client=None,
):
try:
if client is None:
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
raw_response = await openai_aclient.completions.with_raw_response.create(
**data
)
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
response_obj = TextCompletionResponse(**response_json)
response_obj._hidden_params.original_response = json.dumps(response_json)
return response_obj
except Exception as e:
raise e
def streaming(
self,
logging_obj,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
api_base: Optional[str] = None,
max_retries=None,
client=None,
organization=None,
):
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
try:
raw_response = openai_client.completions.with_raw_response.create(**data)
response = raw_response.parse()
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
for chunk in streamwrapper:
yield chunk
async def async_streaming(
self,
logging_obj,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
api_base: Optional[str] = None,
client=None,
max_retries=None,
organization=None,
):
if client is None:
openai_client = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
raw_response = await openai_client.completions.with_raw_response.create(**data)
response = raw_response.parse()
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
class OpenAIFilesAPI(BaseLLM):
"""
OpenAI methods to support for batches
- create_file()
- retrieve_file()
- list_files()
- delete_file()
- file_content()
- update_file()
"""
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[OpenAI, AsyncOpenAI]]:
received_args = locals()
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
if _is_async is True:
openai_client = AsyncOpenAI(**data)
else:
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_file(
self,
create_file_data: CreateFileRequest,
openai_client: AsyncOpenAI,
) -> FileObject:
response = await openai_client.files.create(**create_file_data)
return response
def create_file(
self,
_is_async: bool,
create_file_data: CreateFileRequest,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acreate_file( # type: ignore
create_file_data=create_file_data, openai_client=openai_client
)
response = openai_client.files.create(**create_file_data)
return response
async def afile_content(
self,
file_content_request: FileContentRequest,
openai_client: AsyncOpenAI,
) -> HttpxBinaryResponseContent:
response = await openai_client.files.content(**file_content_request)
return response
def file_content(
self,
_is_async: bool,
file_content_request: FileContentRequest,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.afile_content( # type: ignore
file_content_request=file_content_request,
openai_client=openai_client,
)
response = openai_client.files.content(**file_content_request)
return response
async def aretrieve_file(
self,
file_id: str,
openai_client: AsyncOpenAI,
) -> FileObject:
response = await openai_client.files.retrieve(file_id=file_id)
return response
def retrieve_file(
self,
_is_async: bool,
file_id: str,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_file( # type: ignore
file_id=file_id,
openai_client=openai_client,
)
response = openai_client.files.retrieve(file_id=file_id)
return response
async def adelete_file(
self,
file_id: str,
openai_client: AsyncOpenAI,
) -> FileDeleted:
response = await openai_client.files.delete(file_id=file_id)
return response
def delete_file(
self,
_is_async: bool,
file_id: str,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.adelete_file( # type: ignore
file_id=file_id,
openai_client=openai_client,
)
response = openai_client.files.delete(file_id=file_id)
return response
async def alist_files(
self,
openai_client: AsyncOpenAI,
purpose: Optional[str] = None,
):
if isinstance(purpose, str):
response = await openai_client.files.list(purpose=purpose)
else:
response = await openai_client.files.list()
return response
def list_files(
self,
_is_async: bool,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
purpose: Optional[str] = None,
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.alist_files( # type: ignore
purpose=purpose,
openai_client=openai_client,
)
if isinstance(purpose, str):
response = openai_client.files.list(purpose=purpose)
else:
response = openai_client.files.list()
return response
class OpenAIBatchesAPI(BaseLLM):
"""
OpenAI methods to support for batches
- create_batch()
- retrieve_batch()
- cancel_batch()
- list_batch()
"""
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[OpenAI, AsyncOpenAI]]:
received_args = locals()
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
if _is_async is True:
openai_client = AsyncOpenAI(**data)
else:
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_batch(
self,
create_batch_data: CreateBatchRequest,
openai_client: AsyncOpenAI,
) -> Batch:
response = await openai_client.batches.create(**create_batch_data)
return response
def create_batch(
self,
_is_async: bool,
create_batch_data: CreateBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acreate_batch( # type: ignore
create_batch_data=create_batch_data, openai_client=openai_client
)
response = openai_client.batches.create(**create_batch_data)
return response
async def aretrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
openai_client: AsyncOpenAI,
) -> Batch:
verbose_logger.debug("retrieving batch, args= %s", retrieve_batch_data)
response = await openai_client.batches.retrieve(**retrieve_batch_data)
return response
def retrieve_batch(
self,
_is_async: bool,
retrieve_batch_data: RetrieveBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
)
response = openai_client.batches.retrieve(**retrieve_batch_data)
return response
def cancel_batch(
self,
_is_async: bool,
cancel_batch_data: CancelBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
response = openai_client.batches.cancel(**cancel_batch_data)
return response
async def alist_batches(
self,
openai_client: AsyncOpenAI,
after: Optional[str] = None,
limit: Optional[int] = None,
):
verbose_logger.debug("listing batches, after= %s, limit= %s", after, limit)
response = await openai_client.batches.list(after=after, limit=limit) # type: ignore
return response
def list_batches(
self,
_is_async: bool,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
after: Optional[str] = None,
limit: Optional[int] = None,
client: Optional[OpenAI] = None,
):
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.alist_batches( # type: ignore
openai_client=openai_client, after=after, limit=limit
)
response = openai_client.batches.list(after=after, limit=limit) # type: ignore
return response
class OpenAIAssistantsAPI(BaseLLM):
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
def async_get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> AsyncOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = AsyncOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
### ASSISTANTS ###
async def async_get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> AsyncCursorPage[Assistant]:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.list()
return response
# fmt: off
@overload
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_assistants: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
...
@overload
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_assistants: Optional[Literal[False]],
) -> SyncCursorPage[Assistant]:
...
# fmt: on
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_assistants=None,
):
if aget_assistants is not None and aget_assistants == True:
return self.async_get_assistants(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.assistants.list()
return response
# Create Assistant
async def async_create_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
create_assistant_data: dict,
) -> Assistant:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.create(**create_assistant_data)
return response
def create_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
create_assistant_data: dict,
client=None,
async_create_assistants=None,
):
if async_create_assistants is not None and async_create_assistants == True:
return self.async_create_assistants(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
create_assistant_data=create_assistant_data,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.assistants.create(**create_assistant_data)
return response
# Delete Assistant
async def async_delete_assistant(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
assistant_id: str,
) -> AssistantDeleted:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.delete(assistant_id=assistant_id)
return response
def delete_assistant(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
assistant_id: str,
client=None,
async_delete_assistants=None,
):
if async_delete_assistants is not None and async_delete_assistants == True:
return self.async_delete_assistant(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
assistant_id=assistant_id,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.assistants.delete(assistant_id=assistant_id)
return response
### MESSAGES ###
async def a_add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> OpenAIMessage:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data # type: ignore
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
# fmt: off
@overload
def add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
a_add_message: Literal[True],
) -> Coroutine[None, None, OpenAIMessage]:
...
@overload
def add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
a_add_message: Optional[Literal[False]],
) -> OpenAIMessage:
...
# fmt: on
def add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
a_add_message: Optional[bool] = None,
):
if a_add_message is not None and a_add_message == True:
return self.a_add_message(
thread_id=thread_id,
message_data=message_data,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data # type: ignore
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
async def async_get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> AsyncCursorPage[OpenAIMessage]:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
# fmt: off
@overload
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_messages: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
...
@overload
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_messages: Optional[Literal[False]],
) -> SyncCursorPage[OpenAIMessage]:
...
# fmt: on
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_messages=None,
):
if aget_messages is not None and aget_messages == True:
return self.async_get_messages(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
### THREADS ###
async def async_create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = await openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
# fmt: off
@overload
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AsyncOpenAI],
acreate_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[OpenAI],
acreate_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client=None,
acreate_thread=None,
):
"""
Here's an example:
```
from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
# create thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
openai_api.create_thread(messages=[message])
```
"""
if acreate_thread is not None and acreate_thread == True:
return self.async_create_thread(
metadata=metadata,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
messages=messages,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
async def async_get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> Thread:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
# fmt: off
@overload
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_thread=None,
):
if aget_thread is not None and aget_thread == True:
return self.async_get_thread(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
def delete_thread(self):
pass
### RUNS ###
async def arun_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> Run:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response
def async_run_thread_stream(
self,
client: AsyncOpenAI,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
data = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
"instructions": instructions,
"metadata": metadata,
"model": model,
"tools": tools,
}
if event_handler is not None:
data["event_handler"] = event_handler
return client.beta.threads.runs.stream(**data) # type: ignore
def run_thread_stream(
self,
client: OpenAI,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
) -> AssistantStreamManager[AssistantEventHandler]:
data = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
"instructions": instructions,
"metadata": metadata,
"model": model,
"tools": tools,
}
if event_handler is not None:
data["event_handler"] = event_handler
return client.beta.threads.runs.stream(**data) # type: ignore
# fmt: off
@overload
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client,
arun_thread: Literal[True],
event_handler: Optional[AssistantEventHandler],
) -> Coroutine[None, None, Run]:
...
@overload
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client,
arun_thread: Optional[Literal[False]],
event_handler: Optional[AssistantEventHandler],
) -> Run:
...
# fmt: on
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
if arun_thread is not None and arun_thread == True:
if stream is not None and stream == True:
_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
return self.async_run_thread_stream(
client=_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
if stream is not None and stream == True:
return self.run_thread_stream(
client=openai_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response