import hashlib
import json
import os
import time
import traceback
import types
from typing import (
    Any,
    BinaryIO,
    Callable,
    Coroutine,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import httpx
import openai
from openai import AsyncOpenAI, OpenAI
from openai.types.beta.assistant_deleted import AssistantDeleted
from openai.types.file_deleted import FileDeleted
from pydantic import BaseModel
from typing_extensions import overload, override

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import ProviderField
from litellm.utils import (
    Choices,
    CustomStreamWrapper,
    Message,
    ModelResponse,
    TextCompletionResponse,
    TranscriptionResponse,
    Usage,
    convert_to_model_response_object,
)

from ..types.llms.openai import *
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory


class OpenAIError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[httpx.Headers] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.headers = headers
        if request:
            self.request = request
        else:
            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        if response:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class MistralConfig:
    """
    Reference: https://docs.mistral.ai/api/

    The class `MistralConfig` provides configuration for Mistral's Chat API interface. Below are the parameters:

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.

    - `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.

    - `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.

    - `stop` (string or array of strings): Stop generation if this token is detected, or if one of these tokens is detected when an array is provided.

    - `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.

    - `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.

    - `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
    """

    temperature: Optional[int] = None
    top_p: Optional[int] = None
    max_tokens: Optional[int] = None
    tools: Optional[list] = None
    tool_choice: Optional[Literal["auto", "any", "none"]] = None
    random_seed: Optional[int] = None
    safe_prompt: Optional[bool] = None
    response_format: Optional[dict] = None
    stop: Optional[Union[str, list]] = None

    def __init__(
        self,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        max_tokens: Optional[int] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[Literal["auto", "any", "none"]] = None,
        random_seed: Optional[int] = None,
        safe_prompt: Optional[bool] = None,
        response_format: Optional[dict] = None,
        stop: Optional[Union[str, list]] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "stream",
            "temperature",
            "top_p",
            "max_tokens",
            "tools",
            "tool_choice",
            "seed",
            "stop",
            "response_format",
        ]

    def _map_tool_choice(self, tool_choice: str) -> str:
        if tool_choice == "auto" or tool_choice == "none":
            return tool_choice
        elif tool_choice == "required":
            return "any"
        else:  # openai 'tool_choice' object param not supported by Mistral API
            return "any"

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "tools":
                optional_params["tools"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "tool_choice" and isinstance(value, str):
                optional_params["tool_choice"] = self._map_tool_choice(
                    tool_choice=value
                )
            if param == "seed":
                optional_params["extra_body"] = {"random_seed": value}
            if param == "response_format":
                optional_params["response_format"] = value
        return optional_params


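# Illustrative sketch (hypothetical values) of how MistralConfig above remaps
# OpenAI-style params: the OpenAI `seed` becomes Mistral's `random_seed` (sent
# via `extra_body`), and `tool_choice="required"` maps to Mistral's "any".
#
#   MistralConfig().map_openai_params(
#       non_default_params={"seed": 42, "tool_choice": "required"},
#       optional_params={},
#   )
#   # -> {"tool_choice": "any", "extra_body": {"random_seed": 42}}

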
class MistralEmbeddingConfig:
    """
    Reference: https://docs.mistral.ai/api/#operation/createEmbedding
    """

    def __init__(
        self,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "encoding_format",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "encoding_format":
                optional_params["encoding_format"] = value
        return optional_params


class AzureAIStudioConfig:
    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Azure AI Studio API Key.",
                field_value="zEJ...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Azure AI Studio API Base.",
                field_value="https://Mistral-serverless.",
            ),
        ]


class DeepInfraConfig:
    """
    Reference: https://deepinfra.com/docs/advanced/openai_api

    The class `DeepInfraConfig` provides configuration for the DeepInfra Chat Completions API interface.
    """

    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None
    tools: Optional[list] = None
    tool_choice: Optional[Union[str, dict]] = None

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "stream",
            "frequency_penalty",
            "function_call",
            "functions",
            "logit_bias",
            "max_tokens",
            "n",
            "presence_penalty",
            "stop",
            "temperature",
            "top_p",
            "response_format",
            "tools",
            "tool_choice",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params()
        for param, value in non_default_params.items():
            if (
                param == "temperature"
                and value == 0
                and model == "mistralai/Mistral-7B-Instruct-v0.1"
            ):  # this model does not support temperature == 0
                value = 0.0001  # close to 0
            if param == "tool_choice":
                if (
                    value != "auto" and value != "none"
                ):  # https://deepinfra.com/docs/advanced/function_calling
                    ## UNSUPPORTED TOOL CHOICE VALUE
                    if litellm.drop_params is True or drop_params is True:
                        value = None
                    else:
                        raise litellm.utils.UnsupportedParamsError(
                            message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
                                value
                            ),
                            status_code=400,
                        )
            if param in supported_openai_params:
                if value is not None:
                    optional_params[param] = value
        return optional_params


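# Sketch (hypothetical call) of the DeepInfra-specific handling above:
# temperature=0 is nudged to 0.0001 for mistralai/Mistral-7B-Instruct-v0.1,
# and an unsupported tool_choice is either dropped (drop_params=True) or raises.
#
#   DeepInfraConfig().map_openai_params(
#       non_default_params={"temperature": 0, "tool_choice": "required"},
#       optional_params={},
#       model="mistralai/Mistral-7B-Instruct-v0.1",
#       drop_params=True,
#   )
#   # -> {"temperature": 0.0001}  (tool_choice is set to None and omitted)

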
class GroqConfig:
    """
    Reference: https://console.groq.com/docs/text-chat

    The class `GroqConfig` provides configuration for Groq's Chat Completions API interface.
    """

    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None
    tools: Optional[list] = None
    tool_choice: Optional[Union[str, dict]] = None

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params_stt(self):
        return [
            "prompt",
            "response_format",
            "temperature",
            "language",
        ]

    def get_supported_openai_response_formats_stt(self) -> List[str]:
        return ["json", "verbose_json", "text"]

    def map_openai_params_stt(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        response_formats = self.get_supported_openai_response_formats_stt()
        for param, value in non_default_params.items():
            if param == "response_format":
                if value in response_formats:
                    optional_params[param] = value
                else:
                    if litellm.drop_params is True or drop_params is True:
                        pass
                    else:
                        raise litellm.utils.UnsupportedParamsError(
                            message="Groq doesn't support response_format={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
                                value
                            ),
                            status_code=400,
                        )
            else:
                optional_params[param] = value
        return optional_params


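# Sketch (hypothetical call) of the Groq speech-to-text mapping above: a
# response_format outside {"json", "verbose_json", "text"} is silently dropped
# when drop_params is enabled, otherwise an UnsupportedParamsError is raised.
#
#   GroqConfig().map_openai_params_stt(
#       non_default_params={"response_format": "srt", "temperature": 0.2},
#       optional_params={},
#       model="whisper-large-v3",
#       drop_params=True,
#   )
#   # -> {"temperature": 0.2}

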
class OpenAIConfig:
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `OpenAIConfig` provides configuration for OpenAI's Chat API interface. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.

    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> list:
        base_params = [
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "top_logprobs",
            "max_tokens",
            "n",
            "presence_penalty",
            "seed",
            "stop",
            "stream",
            "stream_options",
            "temperature",
            "top_p",
            "tools",
            "tool_choice",
            "function_call",
            "functions",
            "max_retries",
            "extra_headers",
            "parallel_tool_calls",
        ]  # works across all models

        model_specific_params = []
        if (
            model != "gpt-3.5-turbo-16k" and model != "gpt-4"
        ):  # gpt-3.5-turbo-16k and gpt-4 do not support 'response_format'
            model_specific_params.append("response_format")

        if (
            model in litellm.open_ai_chat_completion_models
        ) or model in litellm.open_ai_text_completion_models:
            model_specific_params.append(
                "user"
            )  # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
        return base_params + model_specific_params

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params


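# Sketch (hypothetical call) for OpenAIConfig above: map_openai_params only
# copies keys reported by get_supported_openai_params(model), so anything
# outside that list is left out of the mapped params.
#
#   OpenAIConfig().map_openai_params(
#       non_default_params={"temperature": 0.1, "seed": 7},
#       optional_params={},
#       model="gpt-4",
#   )
#   # -> {"temperature": 0.1, "seed": 7}

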
class OpenAITextCompletionConfig:
    """
    Reference: https://platform.openai.com/docs/api-reference/completions/create

    The class `OpenAITextCompletionConfig` provides configuration for the OpenAI's text completion API interface. Below are the parameters:

    - `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.

    - `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.

    - `frequency_penalty` (number or null): Defaults to 0. It is a number from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens.

    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.

    - `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.

    - `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.

    - `temperature` (number or null): This optional parameter defines the sampling temperature to use.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    best_of: Optional[int] = None
    echo: Optional[bool] = None
    frequency_penalty: Optional[int] = None
    logit_bias: Optional[dict] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[int] = None,
        logit_bias: Optional[dict] = None,
        logprobs: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def convert_to_chat_model_response_object(
        self,
        response_object: Optional[TextCompletionResponse] = None,
        model_response_object: Optional[ModelResponse] = None,
    ):
        try:
            ## RESPONSE OBJECT
            if response_object is None or model_response_object is None:
                raise ValueError("Error in response object format")
            choice_list = []
            for idx, choice in enumerate(response_object["choices"]):
                message = Message(
                    content=choice["text"],
                    role="assistant",
                )
                choice = Choices(
                    finish_reason=choice["finish_reason"], index=idx, message=message
                )
                choice_list.append(choice)
            model_response_object.choices = choice_list

            if "usage" in response_object:
                setattr(model_response_object, "usage", response_object["usage"])

            if "id" in response_object:
                model_response_object.id = response_object["id"]

            if "model" in response_object:
                model_response_object.model = response_object["model"]

            model_response_object._hidden_params["original_response"] = (
                response_object  # track original response, if users make a litellm.text_completion() request, we can return the original response
            )
            return model_response_object
        except Exception as e:
            raise e


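# Sketch (hypothetical shapes): convert_to_chat_model_response_object above lifts
# each text-completion choice into a chat-style assistant message, so a choice
# like {"text": "hi", "finish_reason": "stop"} becomes a Choices entry whose
# message is {"role": "assistant", "content": "hi"}; the raw response is kept
# under model_response_object._hidden_params["original_response"].

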
class OpenAIChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def _get_openai_client(
        self,
        is_async: bool,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
        max_retries: Optional[int] = 2,
        organization: Optional[str] = None,
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        args = locals()
        if client is None:
            if not isinstance(max_retries, int):
                raise OpenAIError(
                    status_code=422,
                    message="max retries must be an int. Passed in value: {}".format(
                        max_retries
                    ),
                )
            # Creating a new OpenAI Client
            # check in memory cache before creating a new one
            # Convert the API key to bytes
            hashed_api_key = None
            if api_key is not None:
                hash_object = hashlib.sha256(api_key.encode())
                # Hexadecimal representation of the hash
                hashed_api_key = hash_object.hexdigest()

            _cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"

            if _cache_key in litellm.in_memory_llm_clients_cache:
                return litellm.in_memory_llm_clients_cache[_cache_key]
            if is_async:
                _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                )
            else:
                _new_client = OpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.client_session,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                )

            ## SAVE CACHE KEY
            litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
            return _new_client

        else:
            return client

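    # Note on the client cache above (illustrative, hypothetical key): clients are
    # keyed by a SHA-256 hash of the api_key combined with api_base / timeout /
    # max_retries / organization and the sync-vs-async flag, e.g.
    #   "hashed_api_key=3f2c...,api_base=None,timeout=600.0,max_retries=2,organization=None,is_async=True"
    # so repeated calls with identical settings reuse one OpenAI / AsyncOpenAI instance.
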
    async def make_openai_chat_completion_request(
        self,
        openai_aclient: AsyncOpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
    ) -> Tuple[dict, BaseModel]:
        """
        Helper to call chat.completions.with_raw_response.create and return the
        response headers alongside the parsed response object.
        """
        try:
            raw_response = (
                await openai_aclient.chat.completions.with_raw_response.create(
                    **data, timeout=timeout
                )
            )

            if hasattr(raw_response, "headers"):
                headers = dict(raw_response.headers)
            else:
                headers = {}
            response = raw_response.parse()
            return headers, response
        except OpenAIError as e:
            raise e
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )

    def make_sync_openai_chat_completion_request(
        self,
        openai_client: OpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
    ) -> Tuple[dict, BaseModel]:
        """
        Helper to call chat.completions.with_raw_response.create and return the
        response headers alongside the parsed response object.
        """
        try:
            raw_response = openai_client.chat.completions.with_raw_response.create(
                **data, timeout=timeout
            )

            if hasattr(raw_response, "headers"):
                headers = dict(raw_response.headers)
            else:
                headers = {}
            response = raw_response.parse()
            return headers, response
        except OpenAIError as e:
            raise e
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )

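    # Sketch of the with_raw_response pattern the two helpers above rely on
    # (assumes openai-python v1.x; values hypothetical):
    #
    #   raw = client.chat.completions.with_raw_response.create(
    #       model="gpt-3.5-turbo",
    #       messages=[{"role": "user", "content": "hi"}],
    #   )
    #   headers = dict(raw.headers)  # e.g. x-ratelimit-remaining-requests
    #   completion = raw.parse()     # same ChatCompletion object .create() returns
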
    def completion(
        self,
        model_response: ModelResponse,
        timeout: Union[float, httpx.Timeout],
        optional_params: dict,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        logging_obj=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
        organization: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        drop_params: Optional[bool] = None,
    ):
        super().completion()
        exception_mapping_worked = False
        try:
            if headers:
                optional_params["extra_headers"] = headers
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")

            if not isinstance(timeout, float) and not isinstance(
                timeout, httpx.Timeout
            ):
                raise OpenAIError(
                    status_code=422,
                    message="Timeout needs to be a float or httpx.Timeout",
                )

            if custom_llm_provider is not None and custom_llm_provider != "openai":
                model_response.model = f"{custom_llm_provider}/{model}"
                # process all OpenAI compatible provider logic here
                if custom_llm_provider == "mistral":
                    # check if message content passed in as list, and not string
                    messages = prompt_factory(
                        model=model,
                        messages=messages,
                        custom_llm_provider=custom_llm_provider,
                    )
                if custom_llm_provider == "perplexity" and messages is not None:
                    # check if messages.name is passed + supported, if not supported remove
                    messages = prompt_factory(
                        model=model,
                        messages=messages,
                        custom_llm_provider=custom_llm_provider,
                    )

            for _ in range(
                2
            ):  # if call fails due to alternating messages, retry with reformatted message
                data = {"model": model, "messages": messages, **optional_params}

                try:
                    max_retries = data.pop("max_retries", 2)
                    if acompletion is True:
                        if optional_params.get("stream", False):
                            return self.async_streaming(
                                logging_obj=logging_obj,
                                headers=headers,
                                data=data,
                                model=model,
                                api_base=api_base,
                                api_key=api_key,
                                timeout=timeout,
                                client=client,
                                max_retries=max_retries,
                                organization=organization,
                                drop_params=drop_params,
                            )
                        else:
                            return self.acompletion(
                                data=data,
                                headers=headers,
                                logging_obj=logging_obj,
                                model_response=model_response,
                                api_base=api_base,
                                api_key=api_key,
                                timeout=timeout,
                                client=client,
                                max_retries=max_retries,
                                organization=organization,
                                drop_params=drop_params,
                            )
                    elif optional_params.get("stream", False):
                        return self.streaming(
                            logging_obj=logging_obj,
                            headers=headers,
                            data=data,
                            model=model,
                            api_base=api_base,
                            api_key=api_key,
                            timeout=timeout,
                            client=client,
                            max_retries=max_retries,
                            organization=organization,
                        )
                    else:
                        if not isinstance(max_retries, int):
                            raise OpenAIError(
                                status_code=422, message="max retries must be an int"
                            )

                        openai_client = self._get_openai_client(
                            is_async=False,
                            api_key=api_key,
                            api_base=api_base,
                            timeout=timeout,
                            max_retries=max_retries,
                            organization=organization,
                            client=client,
                        )

                        ## LOGGING
                        logging_obj.pre_call(
                            input=messages,
                            api_key=openai_client.api_key,
                            additional_args={
                                "headers": headers,
                                "api_base": openai_client._base_url._uri_reference,
                                "acompletion": acompletion,
                                "complete_input_dict": data,
                            },
                        )

                        headers, response = (
                            self.make_sync_openai_chat_completion_request(
                                openai_client=openai_client,
                                data=data,
                                timeout=timeout,
                            )
                        )

                        logging_obj.model_call_details["response_headers"] = headers
                        stringified_response = response.model_dump()
                        logging_obj.post_call(
                            input=messages,
                            api_key=api_key,
                            original_response=stringified_response,
                            additional_args={"complete_input_dict": data},
                        )
                        return convert_to_model_response_object(
                            response_object=stringified_response,
                            model_response_object=model_response,
                            _response_headers=headers,
                        )
                except openai.UnprocessableEntityError as e:
                    ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
                    if litellm.drop_params is True or drop_params is True:
                        invalid_params: List[str] = []
                        if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"):  # type: ignore
                            detail = e.body.get("detail")  # type: ignore
                            if (
                                isinstance(detail, List)
                                and len(detail) > 0
                                and isinstance(detail[0], dict)
                            ):
                                for error_dict in detail:
                                    if (
                                        error_dict.get("loc")
                                        and isinstance(error_dict.get("loc"), list)
                                        and len(error_dict.get("loc")) == 2
                                    ):
                                        invalid_params.append(error_dict["loc"][1])

                        new_data = {}
                        for k, v in optional_params.items():
                            if k not in invalid_params:
                                new_data[k] = v
                        optional_params = new_data
                    else:
                        raise e
                    # e.message
                except Exception as e:
                    if print_verbose is not None:
                        print_verbose(f"openai.py: Received openai error - {str(e)}")
                    if (
                        "Conversation roles must alternate user/assistant" in str(e)
                        or "user and assistant roles should be alternating" in str(e)
                    ) and messages is not None:
                        if print_verbose is not None:
                            print_verbose("openai.py: REFORMATS THE MESSAGE!")
                        # reformat messages to ensure user/assistant roles alternate; if there are either 2 consecutive 'user' messages or 2 consecutive 'assistant' messages, add a blank 'user' or 'assistant' message to ensure compatibility
                        new_messages = []
                        for i in range(len(messages) - 1):  # type: ignore
                            new_messages.append(messages[i])
                            if messages[i]["role"] == messages[i + 1]["role"]:
                                if messages[i]["role"] == "user":
                                    new_messages.append(
                                        {"role": "assistant", "content": ""}
                                    )
                                else:
                                    new_messages.append({"role": "user", "content": ""})
                        new_messages.append(messages[-1])
                        messages = new_messages
                    elif (
                        "Last message must have role `user`" in str(e)
                    ) and messages is not None:
                        new_messages = messages
                        new_messages.append({"role": "user", "content": ""})
                        messages = new_messages
                    elif (
                        "unknown field: parameter index is not a valid field" in str(e)
                    ) and "tools" in data:
                        litellm.remove_index_from_tool_calls(
                            tool_calls=data["tools"], messages=messages
                        )
                    else:
                        raise e
        except OpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            if hasattr(e, "status_code"):
                raise OpenAIError(status_code=e.status_code, message=str(e))
            else:
                raise OpenAIError(status_code=500, message=traceback.format_exc())

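    # Sketch (hypothetical 422 body) of the UnprocessableEntityError recovery in
    # completion() above:
    #   {"detail": [{"loc": ["body", "user"], "msg": "extra fields not permitted"}]}
    # yields invalid_params == ["user"]; with drop_params enabled the offending key
    # is removed from optional_params and the surrounding for-loop retries once.
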
    async def acompletion(
        self,
        data: dict,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        timeout: Union[float, httpx.Timeout],
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
        drop_params: Optional[bool] = None,
    ):
        response = None
        for _ in range(
            2
        ):  # if call fails due to alternating messages, retry with reformatted message
            try:
                openai_aclient = self._get_openai_client(
                    is_async=True,
                    api_key=api_key,
                    api_base=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                )

                ## LOGGING
                logging_obj.pre_call(
                    input=data["messages"],
                    api_key=openai_aclient.api_key,
                    additional_args={
                        "headers": {
                            "Authorization": f"Bearer {openai_aclient.api_key}"
                        },
                        "api_base": openai_aclient._base_url._uri_reference,
                        "acompletion": True,
                        "complete_input_dict": data,
                    },
                )

                headers, response = await self.make_openai_chat_completion_request(
                    openai_aclient=openai_aclient, data=data, timeout=timeout
                )
                stringified_response = response.model_dump()
                logging_obj.post_call(
                    input=data["messages"],
                    api_key=api_key,
                    original_response=stringified_response,
                    additional_args={"complete_input_dict": data},
                )
                logging_obj.model_call_details["response_headers"] = headers
                return convert_to_model_response_object(
                    response_object=stringified_response,
                    model_response_object=model_response,
                    hidden_params={"headers": headers},
                    _response_headers=headers,
                )
            except openai.UnprocessableEntityError as e:
                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
                if litellm.drop_params is True or drop_params is True:
                    invalid_params: List[str] = []
                    if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"):  # type: ignore
                        detail = e.body.get("detail")  # type: ignore
                        if (
                            isinstance(detail, List)
                            and len(detail) > 0
                            and isinstance(detail[0], dict)
                        ):
                            for error_dict in detail:
                                if (
                                    error_dict.get("loc")
                                    and isinstance(error_dict.get("loc"), list)
                                    and len(error_dict.get("loc")) == 2
                                ):
                                    invalid_params.append(error_dict["loc"][1])

                    new_data = {}
                    for k, v in data.items():
                        if k not in invalid_params:
                            new_data[k] = v
                    data = new_data
                else:
                    raise e
                # e.message
            except Exception as e:
                raise e

    def streaming(
        self,
        logging_obj,
        timeout: Union[float, httpx.Timeout],
        data: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
    ):
        openai_client = self._get_openai_client(
            is_async=False,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=data["messages"],
            api_key=api_key,
            additional_args={
                "headers": {"Authorization": f"Bearer {openai_client.api_key}"},
                "api_base": openai_client._base_url._uri_reference,
                "acompletion": False,
                "complete_input_dict": data,
            },
        )
        headers, response = self.make_sync_openai_chat_completion_request(
            openai_client=openai_client,
            data=data,
            timeout=timeout,
        )

        logging_obj.model_call_details["response_headers"] = headers
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
            _response_headers=headers,
        )
        return streamwrapper

    async def async_streaming(
        self,
        timeout: Union[float, httpx.Timeout],
        data: dict,
        model: str,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
        drop_params: Optional[bool] = None,
    ):
        response = None
        for _ in range(2):
            try:
                openai_aclient = self._get_openai_client(
                    is_async=True,
                    api_key=api_key,
                    api_base=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                )
                ## LOGGING
                logging_obj.pre_call(
                    input=data["messages"],
                    api_key=api_key,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                        "acompletion": True,
                        "complete_input_dict": data,
                    },
                )

                headers, response = await self.make_openai_chat_completion_request(
                    openai_aclient=openai_aclient, data=data, timeout=timeout
                )
                logging_obj.model_call_details["response_headers"] = headers
                streamwrapper = CustomStreamWrapper(
                    completion_stream=response,
                    model=model,
                    custom_llm_provider="openai",
                    logging_obj=logging_obj,
                    stream_options=data.get("stream_options", None),
                    _response_headers=headers,
                )
                return streamwrapper
            except openai.UnprocessableEntityError as e:
                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
                if litellm.drop_params is True or drop_params is True:
                    invalid_params: List[str] = []
                    if e.body is not None and isinstance(e.body, dict) and e.body.get("detail"):  # type: ignore
                        detail = e.body.get("detail")  # type: ignore
                        if (
                            isinstance(detail, List)
                            and len(detail) > 0
                            and isinstance(detail[0], dict)
                        ):
                            for error_dict in detail:
                                if (
                                    error_dict.get("loc")
                                    and isinstance(error_dict.get("loc"), list)
                                    and len(error_dict.get("loc")) == 2
                                ):
                                    invalid_params.append(error_dict["loc"][1])

                    new_data = {}
                    for k, v in data.items():
                        if k not in invalid_params:
                            new_data[k] = v
                    data = new_data
                else:
                    raise e
            except Exception as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
                if isinstance(e, OpenAIError):
                    raise e
                if response is not None and hasattr(response, "text"):
                    raise OpenAIError(
                        status_code=500,
                        message=f"{str(e)}\n\nOriginal Response: {response.text}",
                    )
                else:
                    if type(e).__name__ == "ReadTimeout":
                        raise OpenAIError(
                            status_code=408, message=f"{type(e).__name__}"
                        )
                    elif hasattr(e, "status_code"):
                        raise OpenAIError(status_code=e.status_code, message=str(e))
                    else:
                        raise OpenAIError(status_code=500, message=f"{str(e)}")

    # Embedding
    async def make_openai_embedding_request(
        self,
        openai_aclient: AsyncOpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
    ):
        """
        Helper to call embeddings.with_raw_response.create and return the
        response headers alongside the parsed response object.
        """
        try:
            raw_response = await openai_aclient.embeddings.with_raw_response.create(
                **data, timeout=timeout
            )  # type: ignore
            headers = dict(raw_response.headers)
            response = raw_response.parse()
            return headers, response
        except Exception as e:
            raise e

    def make_sync_openai_embedding_request(
        self,
        openai_client: OpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
    ):
        """
        Helper to call embeddings.with_raw_response.create and return the
        response headers alongside the parsed response object.
        """
        try:
            raw_response = openai_client.embeddings.with_raw_response.create(
                **data, timeout=timeout
            )  # type: ignore

            headers = dict(raw_response.headers)
            response = raw_response.parse()
            return headers, response
        except Exception as e:
            raise e

    async def aembedding(
        self,
        input: list,
        data: dict,
        model_response: litellm.utils.EmbeddingResponse,
        timeout: float,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client: Optional[AsyncOpenAI] = None,
        max_retries=None,
    ):
        response = None
        try:
            openai_aclient = self._get_openai_client(
                is_async=True,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            )
            headers, response = await self.make_openai_embedding_request(
                openai_aclient=openai_aclient, data=data, timeout=timeout
            )
            logging_obj.model_call_details["response_headers"] = headers
            stringified_response = response.model_dump()
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            return convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
                response_type="embedding",
                _response_headers=headers,
            )  # type: ignore
        except OpenAIError as e:
            raise e
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )

    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        logging_obj,
        model_response: litellm.utils.EmbeddingResponse,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        optional_params=None,
        client=None,
        aembedding=None,
    ):
        super().embedding()
        exception_mapping_worked = False
        try:
            data = {"model": model, "input": input, **optional_params}
            max_retries = data.pop("max_retries", 2)
            if not isinstance(max_retries, int):
                raise OpenAIError(status_code=422, message="max retries must be an int")
            ## LOGGING
            logging_obj.pre_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data, "api_base": api_base},
            )

            if aembedding is True:
                response = self.aembedding(
                    data=data,
                    input=input,
                    logging_obj=logging_obj,
                    model_response=model_response,
                    api_base=api_base,
                    api_key=api_key,
                    timeout=timeout,
                    client=client,
                    max_retries=max_retries,
                )
                return response

            openai_client = self._get_openai_client(
                is_async=False,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            )

            ## embedding CALL
            headers: Optional[Dict] = None
            headers, sync_embedding_response = self.make_sync_openai_embedding_request(
                openai_client=openai_client, data=data, timeout=timeout
            )  # type: ignore

            ## LOGGING
            logging_obj.model_call_details["response_headers"] = headers
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=sync_embedding_response,
            )
            return convert_to_model_response_object(
                response_object=sync_embedding_response.model_dump(),
                model_response_object=model_response,
                _response_headers=headers,
                response_type="embedding",
            )  # type: ignore
        except OpenAIError as e:
            raise e
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )

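    # Usage sketch (hypothetical; `logging_obj` is assumed to be a LiteLLM logging
    # object) for the embedding path above - aembedding=True short-circuits to the
    # coroutine so async callers can await it:
    #
    #   resp = OpenAIChatCompletion().embedding(
    #       model="text-embedding-ada-002",
    #       input=["hello world"],
    #       timeout=600.0,
    #       logging_obj=logging_obj,
    #       model_response=litellm.utils.EmbeddingResponse(),
    #       api_key=os.environ["OPENAI_API_KEY"],
    #       optional_params={},
    #   )
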
    async def aimage_generation(
        self,
        prompt: str,
        data: dict,
        model_response: ModelResponse,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        logging_obj=None,
    ):
        response = None
        try:
            openai_aclient = self._get_openai_client(
                is_async=True,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            )

            response = await openai_aclient.images.generate(**data, timeout=timeout)  # type: ignore
            stringified_response = response.model_dump()
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation")  # type: ignore
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                original_response=str(e),
            )
            raise e

    def image_generation(
        self,
        model: Optional[str],
        prompt: str,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model_response: Optional[litellm.utils.ImageResponse] = None,
        logging_obj=None,
        optional_params=None,
        client=None,
        aimg_generation=None,
    ):
        exception_mapping_worked = False
        try:
            data = {"model": model, "prompt": prompt, **optional_params}
            max_retries = data.pop("max_retries", 2)
            if not isinstance(max_retries, int):
                raise OpenAIError(status_code=422, message="max retries must be an int")

            if aimg_generation is True:
                response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
                return response

            openai_client = self._get_openai_client(
                is_async=False,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            )

            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=openai_client.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {openai_client.api_key}"},
                    "api_base": openai_client._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )

            ## COMPLETION CALL
            response = openai_client.images.generate(**data, timeout=timeout)  # type: ignore
            response = response.model_dump()  # type: ignore
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=response,
            )
            # return response
            return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
        except OpenAIError as e:

            exception_mapping_worked = True
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=str(e),
            )
            raise e
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=str(e),
            )
            if hasattr(e, "status_code"):
                raise OpenAIError(status_code=e.status_code, message=str(e))
            else:
                raise OpenAIError(status_code=500, message=str(e))

    # Audio Transcriptions
    async def make_openai_audio_transcriptions_request(
        self,
        openai_aclient: AsyncOpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
    ):
        """
        Helper to:
        - call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
        - call openai_aclient.audio.transcriptions.create by default
        """
        try:
            if litellm.return_response_headers is True:
                raw_response = (
                    await openai_aclient.audio.transcriptions.with_raw_response.create(
                        **data, timeout=timeout
                    )
                )  # type: ignore
                headers = dict(raw_response.headers)
                response = raw_response.parse()
                return headers, response
            else:
                response = await openai_aclient.audio.transcriptions.create(**data, timeout=timeout)  # type: ignore
                return None, response
        except Exception as e:
            raise e

    def make_sync_openai_audio_transcriptions_request(
        self,
        openai_client: OpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
    ):
        """
        Helper to:
        - call openai_client.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
        - call openai_client.audio.transcriptions.create by default
        """
        try:
            if litellm.return_response_headers is True:
                raw_response = (
                    openai_client.audio.transcriptions.with_raw_response.create(
                        **data, timeout=timeout
                    )
                )  # type: ignore
                headers = dict(raw_response.headers)
                response = raw_response.parse()
                return headers, response
            else:
                response = openai_client.audio.transcriptions.create(**data, timeout=timeout)  # type: ignore
                return None, response
        except Exception as e:
            raise e

    def audio_transcriptions(
        self,
        model: str,
        audio_file: BinaryIO,
        optional_params: dict,
        model_response: TranscriptionResponse,
        timeout: float,
        max_retries: int,
        api_key: Optional[str],
        api_base: Optional[str],
        client=None,
        logging_obj=None,
        atranscription: bool = False,
    ):
        data = {"model": model, "file": audio_file, **optional_params}
        if atranscription is True:
            return self.async_audio_transcriptions(
                audio_file=audio_file,
                data=data,
                model_response=model_response,
                timeout=timeout,
                api_key=api_key,
                api_base=api_base,
                client=client,
                max_retries=max_retries,
                logging_obj=logging_obj,
            )

        openai_client = self._get_openai_client(
            is_async=False,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
        )
        _, response = self.make_sync_openai_audio_transcriptions_request(
            openai_client=openai_client,
            data=data,
            timeout=timeout,
        )

        if isinstance(response, BaseModel):
            stringified_response = response.model_dump()
        else:
            stringified_response = TranscriptionResponse(text=response).model_dump()

        ## LOGGING
        logging_obj.post_call(
            input=audio_file.name,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=stringified_response,
        )
        hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
        final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
        return final_response

    async def async_audio_transcriptions(
        self,
        audio_file: BinaryIO,
        data: dict,
        model_response: TranscriptionResponse,
        timeout: float,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
    ):
        try:
            openai_aclient = self._get_openai_client(
                is_async=True,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            )

            headers, response = await self.make_openai_audio_transcriptions_request(
                openai_aclient=openai_aclient,
                data=data,
                timeout=timeout,
            )
            logging_obj.model_call_details["response_headers"] = headers
            if isinstance(response, BaseModel):
                stringified_response = response.model_dump()
            else:
                stringified_response = TranscriptionResponse(text=response).model_dump()
            ## LOGGING
            logging_obj.post_call(
                input=audio_file.name,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=audio_file.name,
                api_key=api_key,
                original_response=str(e),
            )
            raise e

    def audio_speech(
        self,
        model: str,
        input: str,
        voice: str,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        organization: Optional[str],
        project: Optional[str],
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        aspeech: Optional[bool] = None,
        client=None,
    ) -> HttpxBinaryResponseContent:

        if aspeech is not None and aspeech is True:
            return self.async_audio_speech(
                model=model,
                input=input,
                voice=voice,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                organization=organization,
                project=project,
                max_retries=max_retries,
                timeout=timeout,
                client=client,
            )  # type: ignore

        openai_client = self._get_openai_client(
            is_async=False,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )

        response = openai_client.audio.speech.create(
            model=model,
            voice=voice,  # type: ignore
            input=input,
            **optional_params,
        )
        return response

    async def async_audio_speech(
        self,
        model: str,
        input: str,
        voice: str,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        organization: Optional[str],
        project: Optional[str],
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        client=None,
    ) -> HttpxBinaryResponseContent:

        openai_client = self._get_openai_client(
            is_async=True,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )

        response = await openai_client.audio.speech.create(
            model=model,
            voice=voice,  # type: ignore
            input=input,
            **optional_params,
        )

        return response

    async def ahealth_check(
        self,
        model: Optional[str],
        api_key: str,
        timeout: float,
        mode: str,
        messages: Optional[list] = None,
        input: Optional[list] = None,
        prompt: Optional[str] = None,
        organization: Optional[str] = None,
        api_base: Optional[str] = None,
    ):
        client = AsyncOpenAI(
            api_key=api_key,
            timeout=timeout,
            organization=organization,
            base_url=api_base,
        )
        if model is None and mode != "image_generation":
            raise Exception("model is not set")

        completion = None

        if mode == "completion":
            completion = await client.completions.with_raw_response.create(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "chat":
            if messages is None:
                raise Exception("messages is not set")
            completion = await client.chat.completions.with_raw_response.create(
                model=model,  # type: ignore
                messages=messages,  # type: ignore
            )
        elif mode == "embedding":
            if input is None:
                raise Exception("input is not set")
            completion = await client.embeddings.with_raw_response.create(
                model=model,  # type: ignore
                input=input,  # type: ignore
            )
        elif mode == "image_generation":
            if prompt is None:
                raise Exception("prompt is not set")
            completion = await client.images.with_raw_response.generate(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_transcription":
            # Get the current directory of the file being run
            pwd = os.path.dirname(os.path.realpath(__file__))
            file_path = os.path.join(pwd, "../tests/gettysburg.wav")
            with open(file_path, "rb") as audio_file:
                completion = await client.audio.transcriptions.with_raw_response.create(
                    file=audio_file,
                    model=model,  # type: ignore
                    prompt=prompt,  # type: ignore
                )
        elif mode == "audio_speech":
            # Get the current directory of the file being run
            completion = await client.audio.speech.with_raw_response.create(
                model=model,  # type: ignore
                input=prompt,  # type: ignore
                voice="alloy",
            )
        else:
            raise ValueError("mode not set, passed in mode: " + mode)
        response = {}

        if completion is None or not hasattr(completion, "headers"):
            raise Exception("invalid completion response")

        if (
            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
        ):  # not provided for dall-e requests
            response["x-ratelimit-remaining-requests"] = completion.headers[
                "x-ratelimit-remaining-requests"
            ]

        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
            response["x-ratelimit-remaining-tokens"] = completion.headers[
                "x-ratelimit-remaining-tokens"
            ]
        return response



class OpenAITextCompletion(BaseLLM):
    _client_session: httpx.Client

    def __init__(self) -> None:
        super().__init__()
        self._client_session = self.create_client_session()

    def validate_environment(self, api_key):
        headers = {
            "content-type": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def completion(
        self,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        messages: list,
        timeout: float,
        logging_obj: LiteLLMLoggingObj,
        print_verbose: Optional[Callable] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        client=None,
        organization: Optional[str] = None,
        headers: Optional[dict] = None,
    ):
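        """
        Text-completion entrypoint: builds the request payload, then dispatches
        to the sync/async and streaming/non-streaming variants based on
        `acompletion` and `optional_params["stream"]`.

        A minimal sketch (assumes a configured `logging_obj`; values are
        illustrative):

        ```
        handler = OpenAITextCompletion()
        resp = handler.completion(
            model_response=ModelResponse(),
            api_key=os.getenv("OPENAI_API_KEY"),
            model="gpt-3.5-turbo-instruct",
            messages=[{"role": "user", "content": "Say hi"}],
            timeout=600.0,
            logging_obj=logging_obj,
            optional_params={},
        )
        ```
        """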
        super().completion()
        try:
            if headers is None:
                headers = self.validate_environment(api_key=api_key)
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")

            if (
                len(messages) > 0
                and "content" in messages[0]
                and isinstance(messages[0]["content"], list)
            ):
                prompt = messages[0]["content"]
            else:
                prompt = [message["content"] for message in messages]  # type: ignore

            # don't send max retries to the api, if set
            data = {"model": model, "prompt": prompt, **optional_params}
            max_retries = data.pop("max_retries", 2)
            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )
            if acompletion is True:
                if optional_params.get("stream", False):
                    return self.async_streaming(
                        logging_obj=logging_obj,
                        api_base=api_base,
                        api_key=api_key,
                        data=data,
                        headers=headers,
                        model_response=model_response,
                        model=model,
                        timeout=timeout,
                        max_retries=max_retries,
                        client=client,
                        organization=organization,
                    )
                else:
                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
            elif optional_params.get("stream", False):
                return self.streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    api_key=api_key,
                    data=data,
                    headers=headers,
                    model_response=model_response,
                    model=model,
                    timeout=timeout,
                    max_retries=max_retries,  # type: ignore
                    client=client,
                    organization=organization,
                )
            else:
                if client is None:
                    openai_client = OpenAI(
                        api_key=api_key,
                        base_url=api_base,
                        http_client=litellm.client_session,
                        timeout=timeout,
                        max_retries=max_retries,  # type: ignore
                        organization=organization,
                    )
                else:
                    openai_client = client

                raw_response = openai_client.completions.with_raw_response.create(**data)  # type: ignore
                response = raw_response.parse()
                response_json = response.model_dump()

                ## LOGGING
                logging_obj.post_call(
                    input=prompt,
                    api_key=api_key,
                    original_response=response_json,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                    },
                )

                ## RESPONSE OBJECT
                return TextCompletionResponse(**response_json)
        except Exception as e:
            raise e

    async def acompletion(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        prompt: str,
        api_key: str,
        model: str,
        timeout: float,
        max_retries=None,
        organization: Optional[str] = None,
        client=None,
    ):
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                )
            else:
                openai_aclient = client

            raw_response = await openai_aclient.completions.with_raw_response.create(
                **data
            )
            response = raw_response.parse()
            response_json = response.model_dump()

            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                original_response=response,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                },
            )
            ## RESPONSE OBJECT
            response_obj = TextCompletionResponse(**response_json)
            response_obj._hidden_params.original_response = json.dumps(response_json)
            return response_obj
        except Exception as e:
            raise e

    def streaming(
        self,
        logging_obj,
        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
        api_base: Optional[str] = None,
        max_retries=None,
        client=None,
        organization=None,
    ):
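        """
        Synchronous streaming: issues the raw completions request and yields
        chunks through `CustomStreamWrapper`; provider errors are re-raised as
        `OpenAIError` with their original status code and headers.
        """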
        if client is None:
            openai_client = OpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,  # type: ignore
                organization=organization,
            )
        else:
            openai_client = client

        try:
            response = openai_client.completions.with_raw_response.create(**data)
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
        )

        for chunk in streamwrapper:
            yield chunk

    async def async_streaming(
        self,
        logging_obj,
        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        organization=None,
    ):
        if client is None:
            openai_client = AsyncOpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
            )
        else:
            openai_client = client

        response = await openai_client.completions.with_raw_response.create(**data)

        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
        )

        async for transformed_chunk in streamwrapper:
            yield transformed_chunk


class OpenAIFilesAPI(BaseLLM):
    """
    OpenAI methods to support files:
    - create_file()
    - retrieve_file()
    - list_files()
    - delete_file()
    - file_content()
    - update_file()
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
        _is_async: bool = False,
    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
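        """
        Return the supplied `client`, or construct a new OpenAI / AsyncOpenAI
        client from whichever of the received arguments are not None
        (`api_base` is forwarded as the SDK's `base_url`).
        """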
        received_args = locals()
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client" or k == "_is_async":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            if _is_async is True:
                openai_client = AsyncOpenAI(**data)
            else:
                openai_client = OpenAI(**data)  # type: ignore
        else:
            openai_client = client

        return openai_client

    async def acreate_file(
        self,
        create_file_data: CreateFileRequest,
        openai_client: AsyncOpenAI,
    ) -> FileObject:
        response = await openai_client.files.create(**create_file_data)
        return response

    def create_file(
        self,
        _is_async: bool,
        create_file_data: CreateFileRequest,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
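        """
        Upload a file to OpenAI, returning a `FileObject` (or a coroutine
        resolving to one when `_is_async=True`).

        A minimal sketch; the file name and purpose below are illustrative:

        ```
        files_api = OpenAIFilesAPI()
        file_obj = files_api.create_file(
            _is_async=False,
            create_file_data={"file": open("batch_input.jsonl", "rb"), "purpose": "batch"},
            api_base="https://api.openai.com/v1",
            api_key=os.getenv("OPENAI_API_KEY"),
            timeout=600.0,
            max_retries=2,
            organization=None,
        )
        ```
        """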
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.acreate_file(  # type: ignore
                create_file_data=create_file_data, openai_client=openai_client
            )
        response = openai_client.files.create(**create_file_data)
        return response

    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        openai_client: AsyncOpenAI,
    ) -> HttpxBinaryResponseContent:
        response = await openai_client.files.content(**file_content_request)
        return response

    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.afile_content(  # type: ignore
                file_content_request=file_content_request,
                openai_client=openai_client,
            )
        response = openai_client.files.content(**file_content_request)

        return response

    async def aretrieve_file(
        self,
        file_id: str,
        openai_client: AsyncOpenAI,
    ) -> FileObject:
        response = await openai_client.files.retrieve(file_id=file_id)
        return response

    def retrieve_file(
        self,
        _is_async: bool,
        file_id: str,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.aretrieve_file(  # type: ignore
                file_id=file_id,
                openai_client=openai_client,
            )
        response = openai_client.files.retrieve(file_id=file_id)

        return response

    async def adelete_file(
        self,
        file_id: str,
        openai_client: AsyncOpenAI,
    ) -> FileDeleted:
        response = await openai_client.files.delete(file_id=file_id)
        return response

    def delete_file(
        self,
        _is_async: bool,
        file_id: str,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.adelete_file(  # type: ignore
                file_id=file_id,
                openai_client=openai_client,
            )
        response = openai_client.files.delete(file_id=file_id)

        return response

    async def alist_files(
        self,
        openai_client: AsyncOpenAI,
        purpose: Optional[str] = None,
    ):
        if isinstance(purpose, str):
            response = await openai_client.files.list(purpose=purpose)
        else:
            response = await openai_client.files.list()
        return response

    def list_files(
        self,
        _is_async: bool,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        purpose: Optional[str] = None,
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.alist_files(  # type: ignore
                purpose=purpose,
                openai_client=openai_client,
            )

        if isinstance(purpose, str):
            response = openai_client.files.list(purpose=purpose)
        else:
            response = openai_client.files.list()

        return response


class OpenAIBatchesAPI(BaseLLM):
    """
    OpenAI methods to support batches:
    - create_batch()
    - retrieve_batch()
    - cancel_batch()
    - list_batches()
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
        _is_async: bool = False,
    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
        received_args = locals()
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client" or k == "_is_async":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            if _is_async is True:
                openai_client = AsyncOpenAI(**data)
            else:
                openai_client = OpenAI(**data)  # type: ignore
        else:
            openai_client = client

        return openai_client

    async def acreate_batch(
        self,
        create_batch_data: CreateBatchRequest,
        openai_client: AsyncOpenAI,
    ) -> Batch:
        response = await openai_client.batches.create(**create_batch_data)
        return response

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
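        """
        Create a batch job from a previously uploaded input file, returning a
        `Batch` (or a coroutine resolving to one when `_is_async=True`).

        A minimal sketch; the file id below is a placeholder:

        ```
        batches_api = OpenAIBatchesAPI()
        batch = batches_api.create_batch(
            _is_async=False,
            create_batch_data={
                "input_file_id": "file-abc123",
                "endpoint": "/v1/chat/completions",
                "completion_window": "24h",
            },
            api_key=os.getenv("OPENAI_API_KEY"),
            api_base=None,
            timeout=600.0,
            max_retries=2,
            organization=None,
        )
        ```
        """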
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.acreate_batch(  # type: ignore
                create_batch_data=create_batch_data, openai_client=openai_client
            )
        response = openai_client.batches.create(**create_batch_data)
        return response

    async def aretrieve_batch(
        self,
        retrieve_batch_data: RetrieveBatchRequest,
        openai_client: AsyncOpenAI,
    ) -> Batch:
        verbose_logger.debug("retrieving batch, args= %s", retrieve_batch_data)
        response = await openai_client.batches.retrieve(**retrieve_batch_data)
        return response

    def retrieve_batch(
        self,
        _is_async: bool,
        retrieve_batch_data: RetrieveBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.aretrieve_batch(  # type: ignore
                retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
            )
        response = openai_client.batches.retrieve(**retrieve_batch_data)
        return response

    def cancel_batch(
        self,
        _is_async: bool,
        cancel_batch_data: CancelBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ):
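        """
        Cancel an in-progress batch. Note: unlike the other batch methods, this
        always executes synchronously; `_is_async` only affects which client
        type is constructed.
        """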
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )
        response = openai_client.batches.cancel(**cancel_batch_data)
        return response

    async def alist_batches(
        self,
        openai_client: AsyncOpenAI,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        verbose_logger.debug("listing batches, after= %s, limit= %s", after, limit)
        response = await openai_client.batches.list(after=after, limit=limit)  # type: ignore
        return response

    def list_batches(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        after: Optional[str] = None,
        limit: Optional[int] = None,
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.alist_batches(  # type: ignore
                openai_client=openai_client, after=after, limit=limit
            )
        response = openai_client.batches.list(after=after, limit=limit)  # type: ignore
        return response


class OpenAIAssistantsAPI(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ) -> OpenAI:
        received_args = locals()
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            openai_client = OpenAI(**data)  # type: ignore
        else:
            openai_client = client

        return openai_client

    def async_get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI] = None,
    ) -> AsyncOpenAI:
        received_args = locals()
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            openai_client = AsyncOpenAI(**data)  # type: ignore
        else:
            openai_client = client

        return openai_client

    ### ASSISTANTS ###

    async def async_get_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
    ) -> AsyncCursorPage[Assistant]:
        openai_client = self.async_get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = await openai_client.beta.assistants.list()

        return response

    # fmt: off

    @overload
    def get_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        aget_assistants: Literal[True],
    ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
        ...

    @overload
    def get_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI],
        aget_assistants: Optional[Literal[False]],
    ) -> SyncCursorPage[Assistant]:
        ...

    # fmt: on

    def get_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client=None,
        aget_assistants=None,
    ):
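        """
        List assistants, dispatching to `async_get_assistants` when
        `aget_assistants=True`.

        A minimal sketch (argument values are illustrative):

        ```
        assistants_api = OpenAIAssistantsAPI()
        page = assistants_api.get_assistants(
            api_key=os.getenv("OPENAI_API_KEY"),
            api_base=None,
            timeout=600.0,
            max_retries=2,
            organization=None,
        )
        ```
        """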
        if aget_assistants is True:
            return self.async_get_assistants(
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
                client=client,
            )
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = openai_client.beta.assistants.list()

        return response

    # Create Assistant
    async def async_create_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        create_assistant_data: dict,
    ) -> Assistant:
        openai_client = self.async_get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = await openai_client.beta.assistants.create(**create_assistant_data)

        return response

    def create_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        create_assistant_data: dict,
        client=None,
        async_create_assistants=None,
    ):
        if async_create_assistants is True:
            return self.async_create_assistants(
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
                client=client,
                create_assistant_data=create_assistant_data,
            )
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = openai_client.beta.assistants.create(**create_assistant_data)
        return response

    # Delete Assistant
    async def async_delete_assistant(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        assistant_id: str,
    ) -> AssistantDeleted:
        openai_client = self.async_get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = await openai_client.beta.assistants.delete(assistant_id=assistant_id)

        return response

    def delete_assistant(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        assistant_id: str,
        client=None,
        async_delete_assistants=None,
    ):
        if async_delete_assistants is True:
            return self.async_delete_assistant(
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
                client=client,
                assistant_id=assistant_id,
            )
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = openai_client.beta.assistants.delete(assistant_id=assistant_id)
        return response

    ### MESSAGES ###

    async def a_add_message(
        self,
        thread_id: str,
        message_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI] = None,
    ) -> OpenAIMessage:
        openai_client = self.async_get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create(  # type: ignore
            thread_id, **message_data  # type: ignore
        )

        # OpenAI may omit `status`; default it to "completed" before validating
        if getattr(thread_message, "status", None) is None:
            thread_message.status = "completed"
        response_obj = OpenAIMessage(**thread_message.dict())
        return response_obj

    # fmt: off

    @overload
    def add_message(
        self,
        thread_id: str,
        message_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        a_add_message: Literal[True],
    ) -> Coroutine[None, None, OpenAIMessage]:
        ...

    @overload
    def add_message(
        self,
        thread_id: str,
        message_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI],
        a_add_message: Optional[Literal[False]],
    ) -> OpenAIMessage:
        ...

    # fmt: on

    def add_message(
        self,
        thread_id: str,
        message_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client=None,
        a_add_message: Optional[bool] = None,
    ):
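        """
        Append a message to an existing thread, dispatching to `a_add_message`
        when `a_add_message=True`.

        A minimal sketch; the thread id below is a placeholder:

        ```
        assistants_api = OpenAIAssistantsAPI()
        message = assistants_api.add_message(
            thread_id="thread_abc123",
            message_data={"role": "user", "content": "Hey, how's it going?"},
            api_key=os.getenv("OPENAI_API_KEY"),
            api_base=None,
            timeout=600.0,
            max_retries=2,
            organization=None,
        )
        ```
        """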
        if a_add_message is True:
            return self.a_add_message(
                thread_id=thread_id,
                message_data=message_data,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
                client=client,
            )
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(  # type: ignore
            thread_id, **message_data  # type: ignore
        )

        # OpenAI may omit `status`; default it to "completed" before validating
        if getattr(thread_message, "status", None) is None:
            thread_message.status = "completed"
        response_obj = OpenAIMessage(**thread_message.dict())
        return response_obj

    async def async_get_messages(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI] = None,
    ) -> AsyncCursorPage[OpenAIMessage]:
        openai_client = self.async_get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = await openai_client.beta.threads.messages.list(thread_id=thread_id)

        return response

    # fmt: off

    @overload
    def get_messages(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        aget_messages: Literal[True],
    ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
        ...

    @overload
    def get_messages(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI],
        aget_messages: Optional[Literal[False]],
    ) -> SyncCursorPage[OpenAIMessage]:
        ...

    # fmt: on

    def get_messages(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client=None,
        aget_messages=None,
    ):
        if aget_messages is True:
            return self.async_get_messages(
                thread_id=thread_id,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
                client=client,
            )
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = openai_client.beta.threads.messages.list(thread_id=thread_id)

        return response

    ### THREADS ###

    async def async_create_thread(
        self,
        metadata: Optional[dict],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
    ) -> Thread:
        openai_client = self.async_get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        data = {}
        if messages is not None:
            data["messages"] = messages  # type: ignore
        if metadata is not None:
            data["metadata"] = metadata  # type: ignore

        message_thread = await openai_client.beta.threads.create(**data)  # type: ignore

        return Thread(**message_thread.dict())

    # fmt: off

    @overload
    def create_thread(
        self,
        metadata: Optional[dict],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        client: Optional[AsyncOpenAI],
        acreate_thread: Literal[True],
    ) -> Coroutine[None, None, Thread]:
        ...

    @overload
    def create_thread(
        self,
        metadata: Optional[dict],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        client: Optional[OpenAI],
        acreate_thread: Optional[Literal[False]],
    ) -> Thread:
        ...

    # fmt: on

    def create_thread(
        self,
        metadata: Optional[dict],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        client=None,
        acreate_thread=None,
    ):
        """
        Here's an example:
        ```
        from litellm.llms.openai import OpenAIAssistantsAPI, MessageData

        # create thread
        openai_api = OpenAIAssistantsAPI()
        message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
        openai_api.create_thread(messages=[message])
        ```
        """
        if acreate_thread is True:
            return self.async_create_thread(
                metadata=metadata,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
                client=client,
                messages=messages,
            )
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        data = {}
        if messages is not None:
            data["messages"] = messages  # type: ignore
        if metadata is not None:
            data["metadata"] = metadata  # type: ignore

        message_thread = openai_client.beta.threads.create(**data)  # type: ignore

        return Thread(**message_thread.dict())

    async def async_get_thread(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
    ) -> Thread:
        openai_client = self.async_get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = await openai_client.beta.threads.retrieve(thread_id=thread_id)

        return Thread(**response.dict())

    # fmt: off

    @overload
    def get_thread(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        aget_thread: Literal[True],
    ) -> Coroutine[None, None, Thread]:
        ...

    @overload
    def get_thread(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI],
        aget_thread: Optional[Literal[False]],
    ) -> Thread:
        ...

    # fmt: on

    def get_thread(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client=None,
        aget_thread=None,
    ):
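        """
        Retrieve an existing thread by id, dispatching to `async_get_thread`
        when `aget_thread=True`. Returns a litellm `Thread` built from the
        OpenAI response.
        """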
        if aget_thread is True:
            return self.async_get_thread(
                thread_id=thread_id,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
                client=client,
            )
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = openai_client.beta.threads.retrieve(thread_id=thread_id)

        return Thread(**response.dict())

    def delete_thread(self):
        # not implemented yet; placeholder for API parity
        pass

    ### RUNS ###

    async def arun_thread(
        self,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        stream: Optional[bool],
        tools: Optional[Iterable[AssistantToolParam]],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
    ) -> Run:
        openai_client = self.async_get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = await openai_client.beta.threads.runs.create_and_poll(  # type: ignore
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            tools=tools,
        )

        return response

    def async_run_thread_stream(
        self,
        client: AsyncOpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
    ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
        data = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data)  # type: ignore

    def run_thread_stream(
        self,
        client: OpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
    ) -> AssistantStreamManager[AssistantEventHandler]:
        data = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data)  # type: ignore

    # fmt: off

    @overload
    def run_thread(
        self,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        stream: Optional[bool],
        tools: Optional[Iterable[AssistantToolParam]],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client,
        arun_thread: Literal[True],
        event_handler: Optional[AssistantEventHandler],
    ) -> Coroutine[None, None, Run]:
        ...

    @overload
    def run_thread(
        self,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        stream: Optional[bool],
        tools: Optional[Iterable[AssistantToolParam]],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client,
        arun_thread: Optional[Literal[False]],
        event_handler: Optional[AssistantEventHandler],
    ) -> Run:
        ...

    # fmt: on

    def run_thread(
        self,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        stream: Optional[bool],
        tools: Optional[Iterable[AssistantToolParam]],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client=None,
        arun_thread=None,
        event_handler: Optional[AssistantEventHandler] = None,
    ):
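        """
        Run an assistant on a thread. Dispatch order: async + stream ->
        `async_run_thread_stream`; async -> `arun_thread`; sync + stream ->
        `run_thread_stream`; otherwise a blocking `create_and_poll` call.

        A minimal sketch; the ids below are placeholders:

        ```
        assistants_api = OpenAIAssistantsAPI()
        run = assistants_api.run_thread(
            thread_id="thread_abc123",
            assistant_id="asst_abc123",
            additional_instructions=None,
            instructions=None,
            metadata=None,
            model=None,
            stream=False,
            tools=None,
            api_key=os.getenv("OPENAI_API_KEY"),
            api_base=None,
            timeout=600.0,
            max_retries=2,
            organization=None,
        )
        ```
        """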
        if arun_thread is True:
            if stream is True:
                _client = self.async_get_openai_client(
                    api_key=api_key,
                    api_base=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                )
                return self.async_run_thread_stream(
                    client=_client,
                    thread_id=thread_id,
                    assistant_id=assistant_id,
                    additional_instructions=additional_instructions,
                    instructions=instructions,
                    metadata=metadata,
                    model=model,
                    tools=tools,
                    event_handler=event_handler,
                )
            return self.arun_thread(
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                model=model,
                stream=stream,
                tools=tools,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
                client=client,
            )
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        if stream is True:
            return self.run_thread_stream(
                client=openai_client,
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                model=model,
                tools=tools,
                event_handler=event_handler,
            )

        response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            tools=tools,
        )

        return response