Hi {user_name},
@@ -1274,7 +1342,7 @@ Model Info:
API requests will be rejected until either (a) you increase your monthly budget or (b) your monthly usage resets at the beginning of the next calendar month.
- If you have any questions, please send an email to {EMAIL_SUPPORT_CONTACT}
+ If you have any questions, please send an email to {email_support_contact}
Best,
The LiteLLM team
@@ -1384,7 +1452,9 @@ Model Info:
if response.status_code == 200:
pass
else:
- print("Error sending slack alert. Error=", response.text) # noqa
+ verbose_proxy_logger.debug(
+ "Error sending slack alert. Error=", response.text
+ )
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""Log deployment latency"""
@@ -1404,6 +1474,8 @@ Model Info:
final_value = float(
response_s.total_seconds() / completion_tokens
)
+ if isinstance(final_value, timedelta):
+ final_value = final_value.total_seconds()
await self.async_update_daily_reports(
DeploymentMetrics(
diff --git a/litellm/integrations/supabase.py b/litellm/integrations/supabase.py
index 4e6bf517f..7309342e4 100644
--- a/litellm/integrations/supabase.py
+++ b/litellm/integrations/supabase.py
@@ -110,6 +110,5 @@ class Supabase:
)
except:
- # traceback.print_exc()
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
pass
diff --git a/litellm/integrations/test_httpx.py b/litellm/integrations/test_httpx.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/litellm/integrations/traceloop.py b/litellm/integrations/traceloop.py
index bbdb9a1b0..e1c419c6f 100644
--- a/litellm/integrations/traceloop.py
+++ b/litellm/integrations/traceloop.py
@@ -1,114 +1,149 @@
+import traceback
+from litellm._logging import verbose_logger
+import litellm
+
+
class TraceloopLogger:
def __init__(self):
- from traceloop.sdk.tracing.tracing import TracerWrapper
- from traceloop.sdk import Traceloop
+ try:
+ from traceloop.sdk.tracing.tracing import TracerWrapper
+ from traceloop.sdk import Traceloop
+ from traceloop.sdk.instruments import Instruments
+ from opentelemetry.sdk.trace.export import ConsoleSpanExporter
+ except ModuleNotFoundError as e:
+ verbose_logger.error(
+ f"Traceloop not installed, try running 'pip install traceloop-sdk' to fix this error: {e}\n{traceback.format_exc()}"
+ )
- Traceloop.init(app_name="Litellm-Server", disable_batch=True)
+ Traceloop.init(
+ app_name="Litellm-Server",
+ disable_batch=True,
+ )
self.tracer_wrapper = TracerWrapper()
- def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
- from opentelemetry.trace import SpanKind
+ def log_event(
+ self,
+ kwargs,
+ response_obj,
+ start_time,
+ end_time,
+ user_id,
+ print_verbose,
+ level="DEFAULT",
+ status_message=None,
+ ):
+ from opentelemetry import trace
+ from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.semconv.ai import SpanAttributes
try:
+ print_verbose(
+ f"Traceloop Logging - Enters logging function for model {kwargs}"
+ )
+
tracer = self.tracer_wrapper.get_tracer()
- model = kwargs.get("model")
-
- # LiteLLM uses the standard OpenAI library, so it's already handled by Traceloop SDK
- if kwargs.get("litellm_params").get("custom_llm_provider") == "openai":
- return
-
optional_params = kwargs.get("optional_params", {})
- with tracer.start_as_current_span(
- "litellm.completion",
- kind=SpanKind.CLIENT,
- ) as span:
- if span.is_recording():
+ start_time = int(start_time.timestamp())
+ end_time = int(end_time.timestamp())
+ span = tracer.start_span(
+ "litellm.completion", kind=SpanKind.CLIENT, start_time=start_time
+ )
+
+ if span.is_recording():
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model")
+ )
+ if "stop" in optional_params:
span.set_attribute(
- SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model")
+ SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
+ optional_params.get("stop"),
)
- if "stop" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
- optional_params.get("stop"),
- )
- if "frequency_penalty" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_FREQUENCY_PENALTY,
- optional_params.get("frequency_penalty"),
- )
- if "presence_penalty" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_PRESENCE_PENALTY,
- optional_params.get("presence_penalty"),
- )
- if "top_p" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_TOP_P, optional_params.get("top_p")
- )
- if "tools" in optional_params or "functions" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_REQUEST_FUNCTIONS,
- optional_params.get(
- "tools", optional_params.get("functions")
- ),
- )
- if "user" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_USER, optional_params.get("user")
- )
- if "max_tokens" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_REQUEST_MAX_TOKENS,
- kwargs.get("max_tokens"),
- )
- if "temperature" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_TEMPERATURE, kwargs.get("temperature")
- )
-
- for idx, prompt in enumerate(kwargs.get("messages")):
- span.set_attribute(
- f"{SpanAttributes.LLM_PROMPTS}.{idx}.role",
- prompt.get("role"),
- )
- span.set_attribute(
- f"{SpanAttributes.LLM_PROMPTS}.{idx}.content",
- prompt.get("content"),
- )
-
+ if "frequency_penalty" in optional_params:
span.set_attribute(
- SpanAttributes.LLM_RESPONSE_MODEL, response_obj.get("model")
+ SpanAttributes.LLM_FREQUENCY_PENALTY,
+ optional_params.get("frequency_penalty"),
+ )
+ if "presence_penalty" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_PRESENCE_PENALTY,
+ optional_params.get("presence_penalty"),
+ )
+ if "top_p" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_TOP_P, optional_params.get("top_p")
+ )
+ if "tools" in optional_params or "functions" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_FUNCTIONS,
+ optional_params.get("tools", optional_params.get("functions")),
+ )
+ if "user" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_USER, optional_params.get("user")
+ )
+ if "max_tokens" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_MAX_TOKENS,
+ kwargs.get("max_tokens"),
+ )
+ if "temperature" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_TEMPERATURE,
+ kwargs.get("temperature"),
)
- usage = response_obj.get("usage")
- if usage:
- span.set_attribute(
- SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
- usage.get("total_tokens"),
- )
- span.set_attribute(
- SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
- usage.get("completion_tokens"),
- )
- span.set_attribute(
- SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
- usage.get("prompt_tokens"),
- )
- for idx, choice in enumerate(response_obj.get("choices")):
- span.set_attribute(
- f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason",
- choice.get("finish_reason"),
- )
- span.set_attribute(
- f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role",
- choice.get("message").get("role"),
- )
- span.set_attribute(
- f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content",
- choice.get("message").get("content"),
- )
+ for idx, prompt in enumerate(kwargs.get("messages")):
+ span.set_attribute(
+ f"{SpanAttributes.LLM_PROMPTS}.{idx}.role",
+ prompt.get("role"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_PROMPTS}.{idx}.content",
+ prompt.get("content"),
+ )
+
+ span.set_attribute(
+ SpanAttributes.LLM_RESPONSE_MODEL, response_obj.get("model")
+ )
+ usage = response_obj.get("usage")
+ if usage:
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
+ usage.get("total_tokens"),
+ )
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+ usage.get("completion_tokens"),
+ )
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+ usage.get("prompt_tokens"),
+ )
+
+ for idx, choice in enumerate(response_obj.get("choices")):
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason",
+ choice.get("finish_reason"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role",
+ choice.get("message").get("role"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content",
+ choice.get("message").get("content"),
+ )
+
+ if (
+ level == "ERROR"
+ and status_message is not None
+ and isinstance(status_message, str)
+ ):
+ span.record_exception(Exception(status_message))
+ span.set_status(Status(StatusCode.ERROR, status_message))
+
+ span.end(end_time)
except Exception as e:
print_verbose(f"Traceloop Layer Error - {e}")
diff --git a/litellm/integrations/weights_biases.py b/litellm/integrations/weights_biases.py
index a56233b22..1ac535c4f 100644
--- a/litellm/integrations/weights_biases.py
+++ b/litellm/integrations/weights_biases.py
@@ -217,6 +217,5 @@ class WeightsBiasesLogger:
f"W&B Logging Logging - final response object: {response_obj}"
)
except:
- # traceback.print_exc()
print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
pass
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 1ca048523..8e469a8f4 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -3,6 +3,7 @@ import json
from enum import Enum
import requests, copy # type: ignore
import time
+from functools import partial
from typing import Callable, Optional, List, Union
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
@@ -160,6 +161,36 @@ def validate_environment(api_key, user_headers):
return headers
+async def make_call(
+ client: Optional[AsyncHTTPHandler],
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ if client is None:
+ client = AsyncHTTPHandler() # Create a new client if none provided
+
+ response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200:
+ raise AnthropicError(status_code=response.status_code, message=response.text)
+
+ completion_stream = response.aiter_lines()
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=completion_stream, # Pass the completion stream for logging
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
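`make_call` above is what the streaming path further down hands to `CustomStreamWrapper` via `functools.partial`: the wrapper is built with `completion_stream=None` and only invokes `make_call` (and therefore the HTTP request) when the stream is first consumed. A simplified sketch of that deferral pattern, not litellm's actual `CustomStreamWrapper`:

```
# Simplified illustration of the deferred-call pattern used with
# CustomStreamWrapper below: the request is captured in a zero-argument
# async callable via functools.partial and only awaited on first iteration.
import asyncio
from functools import partial


class DeferredStream:
    def __init__(self, make_call):
        self._make_call = make_call  # async callable returning an async iterator of lines
        self._stream = None

    async def __aiter__(self):
        if self._stream is None:
            # The network call fires here, not at construction time.
            self._stream = await self._make_call()
        async for line in self._stream:
            yield line


async def fake_call(api_base: str, data: str):
    # Stand-in for make_call(); api_base/data mimic the real parameters.
    async def gen():
        yield "event: message"
        yield "data: {}"
    return gen()


async def main():
    stream = DeferredStream(
        make_call=partial(fake_call, api_base="https://example", data="{}")
    )
    async for line in stream:
        print(line)


asyncio.run(main())
```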
class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
@@ -379,23 +410,34 @@ class AnthropicChatCompletion(BaseLLM):
logger_fn=None,
headers={},
):
- self.async_handler = AsyncHTTPHandler(
- timeout=httpx.Timeout(timeout=600.0, connect=5.0)
- )
data["stream"] = True
- response = await self.async_handler.post(
- api_base, headers=headers, data=json.dumps(data), stream=True
- )
+ # async_handler = AsyncHTTPHandler(
+ # timeout=httpx.Timeout(timeout=600.0, connect=20.0)
+ # )
- if response.status_code != 200:
- raise AnthropicError(
- status_code=response.status_code, message=response.text
- )
+ # response = await async_handler.post(
+ # api_base, headers=headers, json=data, stream=True
+ # )
- completion_stream = response.aiter_lines()
+ # if response.status_code != 200:
+ # raise AnthropicError(
+ # status_code=response.status_code, message=response.text
+ # )
+
+ # completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
- completion_stream=completion_stream,
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ client=None,
+ api_base=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
@@ -421,12 +463,10 @@ class AnthropicChatCompletion(BaseLLM):
logger_fn=None,
headers={},
) -> Union[ModelResponse, CustomStreamWrapper]:
- self.async_handler = AsyncHTTPHandler(
+ async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
- response = await self.async_handler.post(
- api_base, headers=headers, data=json.dumps(data)
- )
+ response = await async_handler.post(api_base, headers=headers, json=data)
if stream and _is_function_call:
return self.process_streaming_response(
model=model,
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 02fe4a08f..834fcbea9 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -1,4 +1,5 @@
-from typing import Optional, Union, Any, Literal
+from typing import Optional, Union, Any, Literal, Coroutine, Iterable
+from typing_extensions import overload
import types, requests
from .base import BaseLLM
from litellm.utils import (
@@ -9,6 +10,7 @@ from litellm.utils import (
convert_to_model_response_object,
TranscriptionResponse,
get_secret,
+ UnsupportedParamsError,
)
from typing import Callable, Optional, BinaryIO, List
from litellm import OpenAIConfig
@@ -18,6 +20,22 @@ from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTra
from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid
import os
+from ..types.llms.openai import (
+ AsyncCursorPage,
+ AssistantToolParam,
+ SyncCursorPage,
+ Assistant,
+ MessageData,
+ OpenAIMessage,
+ OpenAICreateThreadParamsMessage,
+ Thread,
+ AssistantToolParam,
+ Run,
+ AssistantEventHandler,
+ AsyncAssistantEventHandler,
+ AsyncAssistantStreamManager,
+ AssistantStreamManager,
+)
class AzureOpenAIError(Exception):
@@ -45,9 +63,9 @@ class AzureOpenAIError(Exception):
) # Call the base class constructor with the parameters it needs
-class AzureOpenAIConfig(OpenAIConfig):
+class AzureOpenAIConfig:
"""
- Reference: https://platform.openai.com/docs/api-reference/chat/create
+ Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
@@ -85,18 +103,111 @@ class AzureOpenAIConfig(OpenAIConfig):
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
- super().__init__(
- frequency_penalty,
- function_call,
- functions,
- logit_bias,
- max_tokens,
- n,
- presence_penalty,
- stop,
- temperature,
- top_p,
- )
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(self):
+ return [
+ "temperature",
+ "n",
+ "stream",
+ "stop",
+ "max_tokens",
+ "tools",
+ "tool_choice",
+ "presence_penalty",
+ "frequency_penalty",
+ "logit_bias",
+ "user",
+ "function_call",
+ "functions",
+ "tools",
+ "tool_choice",
+ "top_p",
+ "logprobs",
+ "top_logprobs",
+ "response_format",
+ "seed",
+ "extra_headers",
+ ]
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ api_version: str, # Y-M-D-{optional}
+ drop_params,
+ ) -> dict:
+ supported_openai_params = self.get_supported_openai_params()
+
+ api_version_times = api_version.split("-")
+ api_version_year = api_version_times[0]
+ api_version_month = api_version_times[1]
+ api_version_day = api_version_times[2]
+ for param, value in non_default_params.items():
+ if param == "tool_choice":
+ """
+ This parameter requires API version 2023-12-01-preview or later
+
+ tool_choice='required' is not supported as of 2024-05-01-preview
+ """
+ ## check if api version supports this param ##
+ if (
+ api_version_year < "2023"
+ or (api_version_year == "2023" and api_version_month < "12")
+ or (
+ api_version_year == "2023"
+ and api_version_month == "12"
+ and api_version_day < "01"
+ )
+ ):
+ if litellm.drop_params == True or (
+ drop_params is not None and drop_params == True
+ ):
+ pass
+ else:
+ raise UnsupportedParamsError(
+ status_code=400,
+ message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
+ )
+ elif value == "required" and (
+ api_version_year == "2024" and api_version_month <= "05"
+ ): ## check if tool_choice value is supported ##
+ if litellm.drop_params == True or (
+ drop_params is not None and drop_params == True
+ ):
+ pass
+ else:
+ raise UnsupportedParamsError(
+ status_code=400,
+ message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
+ )
+ else:
+ optional_params["tool_choice"] = value
+ elif param in supported_openai_params:
+ optional_params[param] = value
+ return optional_params
def get_mapped_special_auth_params(self) -> dict:
return {"token": "azure_ad_token"}
@@ -114,6 +225,68 @@ class AzureOpenAIConfig(OpenAIConfig):
return ["europe", "sweden", "switzerland", "france", "uk"]
+class AzureOpenAIAssistantsAPIConfig:
+ """
+ Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
+ """
+
+ def __init__(
+ self,
+ ) -> None:
+ pass
+
+ def get_supported_openai_create_message_params(self):
+ return [
+ "role",
+ "content",
+ "attachments",
+ "metadata",
+ ]
+
+ def map_openai_params_create_message_params(
+ self, non_default_params: dict, optional_params: dict
+ ):
+ for param, value in non_default_params.items():
+ if param == "role":
+ optional_params["role"] = value
+ if param == "metadata":
+ optional_params["metadata"] = value
+ elif param == "content": # only string accepted
+ if isinstance(value, str):
+ optional_params["content"] = value
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Azure only accepts content as a string.",
+ status_code=400,
+ )
+ elif (
+ param == "attachments"
+ ): # this is a v2 param. Azure currently supports the old 'file_id's param
+ file_ids: List[str] = []
+ if isinstance(value, list):
+ for item in value:
+ if "file_id" in item:
+ file_ids.append(item["file_id"])
+ else:
+ if litellm.drop_params == True:
+ pass
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Azure doesn't support {}. To drop it from the call, set `litellm.drop_params = True.".format(
+ value
+ ),
+ status_code=400,
+ )
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Invalid param. attachments should always be a list. Got={}, Expected=List. Raw value={}".format(
+ type(value), value
+ ),
+ status_code=400,
+ )
+ return optional_params
+
+
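`map_openai_params_create_message_params` above accepts only string `content` and walks v2-style `attachments` entries, collecting their `file_id`s for Azure. A small sketch with made-up values:

```
# Sketch of how the create-message param mapping above treats each field;
# the IDs and text are made up.
config = AzureOpenAIAssistantsAPIConfig()

# String content and file_id-style attachments are accepted.
mapped = config.map_openai_params_create_message_params(
    non_default_params={
        "role": "user",
        "content": "Summarize the attached file.",
        "attachments": [{"file_id": "file-abc123"}],
    },
    optional_params={},
)
print(mapped)  # at minimum {'role': 'user', 'content': 'Summarize the attached file.'}

# Non-string content is rejected with UnsupportedParamsError (unless dropped).
try:
    config.map_openai_params_create_message_params(
        non_default_params={"role": "user", "content": [{"type": "text", "text": "hi"}]},
        optional_params={},
    )
except Exception as e:
    print("rejected:", e)
```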
def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = {
# "api_version": api_version,
@@ -172,9 +345,7 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
possible_azure_ad_token = req_token.json().get("access_token", None)
if possible_azure_ad_token is None:
- raise AzureOpenAIError(
- status_code=422, message="Azure AD Token not returned"
- )
+ raise AzureOpenAIError(status_code=422, message="Azure AD Token not returned")
return possible_azure_ad_token
@@ -245,7 +416,9 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
- azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+ azure_ad_token = get_azure_ad_token_from_oidc(
+ azure_ad_token
+ )
azure_client_params["azure_ad_token"] = azure_ad_token
@@ -1192,3 +1365,828 @@ class AzureChatCompletion(BaseLLM):
response["x-ms-region"] = completion.headers["x-ms-region"]
return response
+
+
+class AzureAssistantsAPI(BaseLLM):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def get_azure_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI] = None,
+ ) -> AzureOpenAI:
+ received_args = locals()
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client":
+ pass
+ elif k == "api_base" and v is not None:
+ data["azure_endpoint"] = v
+ elif v is not None:
+ data[k] = v
+ azure_openai_client = AzureOpenAI(**data) # type: ignore
+ else:
+ azure_openai_client = client
+
+ return azure_openai_client
+
+ def async_get_azure_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI] = None,
+ ) -> AsyncAzureOpenAI:
+ received_args = locals()
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client":
+ pass
+ elif k == "api_base" and v is not None:
+ data["azure_endpoint"] = v
+ elif v is not None:
+ data[k] = v
+
+ azure_openai_client = AsyncAzureOpenAI(**data)
+ # azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore
+ else:
+ azure_openai_client = client
+
+ return azure_openai_client
+
+ ### ASSISTANTS ###
+
+ async def async_get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ ) -> AsyncCursorPage[Assistant]:
+ azure_openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = await azure_openai_client.beta.assistants.list()
+
+ return response
+
+ # fmt: off
+
+ @overload
+ def get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ aget_assistants: Literal[True],
+ ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
+ ...
+
+ @overload
+ def get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI],
+ aget_assistants: Optional[Literal[False]],
+ ) -> SyncCursorPage[Assistant]:
+ ...
+
+ # fmt: on
+
+ def get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client=None,
+ aget_assistants=None,
+ ):
+ if aget_assistants is not None and aget_assistants == True:
+ return self.async_get_assistants(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ azure_openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ api_version=api_version,
+ )
+
+ response = azure_openai_client.beta.assistants.list()
+
+ return response
+
+ ### MESSAGES ###
+
+ async def a_add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI] = None,
+ ) -> OpenAIMessage:
+ openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
+ thread_id, **message_data # type: ignore
+ )
+
+ response_obj: Optional[OpenAIMessage] = None
+ if getattr(thread_message, "status", None) is None:
+ thread_message.status = "completed"
+ response_obj = OpenAIMessage(**thread_message.dict())
+ else:
+ response_obj = OpenAIMessage(**thread_message.dict())
+ return response_obj
+
+ # fmt: off
+
+ @overload
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ a_add_message: Literal[True],
+ ) -> Coroutine[None, None, OpenAIMessage]:
+ ...
+
+ @overload
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI],
+ a_add_message: Optional[Literal[False]],
+ ) -> OpenAIMessage:
+ ...
+
+ # fmt: on
+
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client=None,
+ a_add_message: Optional[bool] = None,
+ ):
+ if a_add_message is not None and a_add_message == True:
+ return self.a_add_message(
+ thread_id=thread_id,
+ message_data=message_data,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
+ thread_id, **message_data # type: ignore
+ )
+
+ response_obj: Optional[OpenAIMessage] = None
+ if getattr(thread_message, "status", None) is None:
+ thread_message.status = "completed"
+ response_obj = OpenAIMessage(**thread_message.dict())
+ else:
+ response_obj = OpenAIMessage(**thread_message.dict())
+ return response_obj
+
+ async def async_get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI] = None,
+ ) -> AsyncCursorPage[OpenAIMessage]:
+ openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+ return response
+
+ # fmt: off
+
+ @overload
+ def get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ aget_messages: Literal[True],
+ ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
+ ...
+
+ @overload
+ def get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI],
+ aget_messages: Optional[Literal[False]],
+ ) -> SyncCursorPage[OpenAIMessage]:
+ ...
+
+ # fmt: on
+
+ def get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client=None,
+ aget_messages=None,
+ ):
+ if aget_messages is not None and aget_messages == True:
+ return self.async_get_messages(
+ thread_id=thread_id,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+ return response
+
+ ### THREADS ###
+
+ async def async_create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ ) -> Thread:
+ openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ data = {}
+ if messages is not None:
+ data["messages"] = messages # type: ignore
+ if metadata is not None:
+ data["metadata"] = metadata # type: ignore
+
+ message_thread = await openai_client.beta.threads.create(**data) # type: ignore
+
+ return Thread(**message_thread.dict())
+
+ # fmt: off
+
+ @overload
+ def create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ client: Optional[AsyncAzureOpenAI],
+ acreate_thread: Literal[True],
+ ) -> Coroutine[None, None, Thread]:
+ ...
+
+ @overload
+ def create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ client: Optional[AzureOpenAI],
+ acreate_thread: Optional[Literal[False]],
+ ) -> Thread:
+ ...
+
+ # fmt: on
+
+ def create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ client=None,
+ acreate_thread=None,
+ ):
+ """
+ Here's an example:
+ ```
+ from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
+
+ # create thread
+ message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
+ openai_api.create_thread(messages=[message])
+ ```
+ """
+ if acreate_thread is not None and acreate_thread == True:
+ return self.async_create_thread(
+ metadata=metadata,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ messages=messages,
+ )
+ azure_openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ data = {}
+ if messages is not None:
+ data["messages"] = messages # type: ignore
+ if metadata is not None:
+ data["metadata"] = metadata # type: ignore
+
+ message_thread = azure_openai_client.beta.threads.create(**data) # type: ignore
+
+ return Thread(**message_thread.dict())
+
+ async def async_get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ ) -> Thread:
+ openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+ return Thread(**response.dict())
+
+ # fmt: off
+
+ @overload
+ def get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ aget_thread: Literal[True],
+ ) -> Coroutine[None, None, Thread]:
+ ...
+
+ @overload
+ def get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI],
+ aget_thread: Optional[Literal[False]],
+ ) -> Thread:
+ ...
+
+ # fmt: on
+
+ def get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client=None,
+ aget_thread=None,
+ ):
+ if aget_thread is not None and aget_thread == True:
+ return self.async_get_thread(
+ thread_id=thread_id,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+ return Thread(**response.dict())
+
+ # def delete_thread(self):
+ # pass
+
+ ### RUNS ###
+
+ async def arun_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ ) -> Run:
+ openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ )
+
+ return response
+
+ def async_run_thread_stream(
+ self,
+ client: AsyncAzureOpenAI,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ tools: Optional[Iterable[AssistantToolParam]],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
+ data = {
+ "thread_id": thread_id,
+ "assistant_id": assistant_id,
+ "additional_instructions": additional_instructions,
+ "instructions": instructions,
+ "metadata": metadata,
+ "model": model,
+ "tools": tools,
+ }
+ if event_handler is not None:
+ data["event_handler"] = event_handler
+ return client.beta.threads.runs.stream(**data) # type: ignore
+
+ def run_thread_stream(
+ self,
+ client: AzureOpenAI,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ tools: Optional[Iterable[AssistantToolParam]],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> AssistantStreamManager[AssistantEventHandler]:
+ data = {
+ "thread_id": thread_id,
+ "assistant_id": assistant_id,
+ "additional_instructions": additional_instructions,
+ "instructions": instructions,
+ "metadata": metadata,
+ "model": model,
+ "tools": tools,
+ }
+ if event_handler is not None:
+ data["event_handler"] = event_handler
+ return client.beta.threads.runs.stream(**data) # type: ignore
+
+ # fmt: off
+
+ @overload
+ def run_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ arun_thread: Literal[True],
+ ) -> Coroutine[None, None, Run]:
+ ...
+
+ @overload
+ def run_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI],
+ arun_thread: Optional[Literal[False]],
+ ) -> Run:
+ ...
+
+ # fmt: on
+
+ def run_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client=None,
+ arun_thread=None,
+ event_handler: Optional[AssistantEventHandler] = None,
+ ):
+ if arun_thread is not None and arun_thread == True:
+ if stream is not None and stream == True:
+ azure_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ return self.async_run_thread_stream(
+ client=azure_client,
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ event_handler=event_handler,
+ )
+ return self.arun_thread(
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ stream=stream,
+ tools=tools,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ if stream is not None and stream == True:
+ return self.run_thread_stream(
+ client=openai_client,
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ event_handler=event_handler,
+ )
+
+ response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ )
+
+ return response
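The `AzureAssistantsAPI` class above mirrors litellm's OpenAI assistants wrapper: each operation has a sync and an async path selected by a flag (`aget_assistants`, `a_add_message`, `acreate_thread`, `aget_thread`, `arun_thread`), with `@overload` stubs pinning the return types. A rough usage sketch; the endpoint, key, API version, and IDs are placeholders:

```
# Rough usage sketch for the AzureAssistantsAPI class added above.
# Endpoint, api key, api version, and assistant/thread IDs are placeholders.
import asyncio

azure_assistants = AzureAssistantsAPI()

common = dict(
    api_key="AZURE_API_KEY",
    api_base="https://my-endpoint.openai.azure.com",
    api_version="2024-02-15-preview",
    azure_ad_token=None,
    timeout=600.0,
    max_retries=2,
    client=None,
)

# Sync path: the flag is falsy, so the call blocks and returns concrete objects.
thread = azure_assistants.create_thread(
    metadata=None,
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    acreate_thread=False,
    **common,
)

# Async path: passing arun_thread=True returns a coroutine to await.
async def run():
    return await azure_assistants.run_thread(
        thread_id=thread.id,
        assistant_id="asst_placeholder",
        additional_instructions=None,
        instructions=None,
        metadata=None,
        model=None,
        stream=False,
        tools=None,
        arun_thread=True,
        **common,
    )

run_result = asyncio.run(run())
```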
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 337055dc2..b011d9512 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -1,7 +1,7 @@
# What is this?
## Initial implementation of calling bedrock via httpx client (allows for async calls).
## V1 - covers cohere + anthropic claude-3 support
-
+from functools import partial
import os, types
import json
from enum import Enum
@@ -38,6 +38,8 @@ from .prompt_templates.factory import (
extract_between_tags,
parse_xml_params,
contains_tag,
+ _bedrock_converse_messages_pt,
+ _bedrock_tools_pt,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
@@ -45,6 +47,12 @@ import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
from litellm.types.llms.bedrock import *
import urllib.parse
+from litellm.types.llms.openai import (
+ ChatCompletionResponseMessage,
+ ChatCompletionToolCallChunk,
+ ChatCompletionToolCallFunctionChunk,
+ ChatCompletionDeltaChunk,
+)
class AmazonCohereChatConfig:
@@ -118,6 +126,8 @@ class AmazonCohereChatConfig:
"presence_penalty",
"seed",
"stop",
+ "tools",
+ "tool_choice",
]
def map_openai_params(
@@ -145,6 +155,68 @@ class AmazonCohereChatConfig:
return optional_params
+async def make_call(
+ client: Optional[AsyncHTTPHandler],
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ if client is None:
+ client = AsyncHTTPHandler() # Create a new client if none provided
+
+ response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200:
+ raise BedrockError(status_code=response.status_code, message=response.text)
+
+ decoder = AWSEventStreamDecoder(model=model)
+ completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
+def make_sync_call(
+ client: Optional[HTTPHandler],
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ if client is None:
+ client = HTTPHandler() # Create a new client if none provided
+
+ response = client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200:
+ raise BedrockError(status_code=response.status_code, message=response.read())
+
+ decoder = AWSEventStreamDecoder(model=model)
+ completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
class BedrockLLM(BaseLLM):
"""
Example call
@@ -217,6 +289,7 @@ class BedrockLLM(BaseLLM):
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
+ aws_web_identity_token: Optional[str] = None,
):
"""
Return a boto3.Credentials object
@@ -231,6 +304,7 @@ class BedrockLLM(BaseLLM):
aws_session_name,
aws_profile_name,
aws_role_name,
+ aws_web_identity_token,
]
# Iterate over parameters and update if needed
@@ -247,10 +321,43 @@ class BedrockLLM(BaseLLM):
aws_session_name,
aws_profile_name,
aws_role_name,
+ aws_web_identity_token,
) = params_to_check
### CHECK STS ###
- if aws_role_name is not None and aws_session_name is not None:
+ if (
+ aws_web_identity_token is not None
+ and aws_role_name is not None
+ and aws_session_name is not None
+ ):
+ oidc_token = get_secret(aws_web_identity_token)
+
+ if oidc_token is None:
+ raise BedrockError(
+ message="OIDC token could not be retrieved from secret manager.",
+ status_code=401,
+ )
+
+ sts_client = boto3.client("sts")
+
+ # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
+ sts_response = sts_client.assume_role_with_web_identity(
+ RoleArn=aws_role_name,
+ RoleSessionName=aws_session_name,
+ WebIdentityToken=oidc_token,
+ DurationSeconds=3600,
+ )
+
+ session = boto3.Session(
+ aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
+ aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
+ aws_session_token=sts_response["Credentials"]["SessionToken"],
+ region_name=aws_region_name,
+ )
+
+ return session.get_credentials()
+ elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
@@ -261,7 +368,16 @@ class BedrockLLM(BaseLLM):
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
- return sts_response["Credentials"]
+ # Extract the credentials from the response and convert to Session Credentials
+ sts_credentials = sts_response["Credentials"]
+ from botocore.credentials import Credentials
+
+ credentials = Credentials(
+ access_key=sts_credentials["AccessKeyId"],
+ secret_key=sts_credentials["SecretAccessKey"],
+ token=sts_credentials["SessionToken"],
+ )
+ return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
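With `aws_web_identity_token` added, `get_credentials` now has three STS-backed paths: web identity (OIDC token + role + session name), plain `assume_role`, and profile/static keys. A sketch of how a caller would supply the new parameters; they travel through `optional_params` into `get_credentials`, and every value below is a placeholder:

```
# Sketch: supplying the new web-identity parameters. The role ARN, session
# name, secret reference, and model are placeholders; "os.environ/..." values
# are resolved through litellm's get_secret before the STS call is made.
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "hello"}],
    aws_region_name="us-east-1",
    aws_role_name="arn:aws:iam::111122223333:role/my-bedrock-role",
    aws_session_name="litellm-session",
    aws_web_identity_token="os.environ/AWS_WEB_IDENTITY_TOKEN",
)
# Under the hood, AssumeRoleWithWebIdentity is called and the returned
# temporary credentials are used to sign the Bedrock request.
```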
@@ -582,6 +698,7 @@ class BedrockLLM(BaseLLM):
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
+ aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
### SET REGION NAME ###
if aws_region_name is None:
@@ -609,6 +726,7 @@ class BedrockLLM(BaseLLM):
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
+ aws_web_identity_token=aws_web_identity_token,
)
### SET RUNTIME ENDPOINT ###
@@ -923,16 +1041,16 @@ class BedrockLLM(BaseLLM):
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
- self.client = AsyncHTTPHandler(**_params) # type: ignore
+ client = AsyncHTTPHandler(**_params) # type: ignore
else:
- self.client = client # type: ignore
+ client = client # type: ignore
try:
- response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
+ response = await client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
- raise BedrockError(status_code=error_code, message=response.text)
+ raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
raise BedrockError(status_code=408, message="Timeout error occurred.")
@@ -968,43 +1086,760 @@ class BedrockLLM(BaseLLM):
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
+ # The call is not made here; instead, we prepare the necessary objects for the stream.
+
+ streaming_response = CustomStreamWrapper(
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ client=client,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
+ model=model,
+ custom_llm_provider="bedrock",
+ logging_obj=logging_obj,
+ )
+ return streaming_response
+
+ def embedding(self, *args, **kwargs):
+ return super().embedding(*args, **kwargs)
+
+
+class AmazonConverseConfig:
+ """
+ Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+ #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
+ """
+
+ maxTokens: Optional[int]
+ stopSequences: Optional[List[str]]
+ temperature: Optional[int]
+ topP: Optional[int]
+
+ def __init__(
+ self,
+ maxTokens: Optional[int] = None,
+ stopSequences: Optional[List[str]] = None,
+ temperature: Optional[int] = None,
+ topP: Optional[int] = None,
+ ) -> None:
+ locals_ = locals()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(self, model: str) -> List[str]:
+ supported_params = [
+ "max_tokens",
+ "stream",
+ "stream_options",
+ "stop",
+ "temperature",
+ "top_p",
+ "extra_headers",
+ ]
+
+ if (
+ model.startswith("anthropic")
+ or model.startswith("mistral")
+ or model.startswith("cohere")
+ ):
+ supported_params.append("tools")
+
+ if model.startswith("anthropic") or model.startswith("mistral"):
+ # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+ supported_params.append("tool_choice")
+
+ return supported_params
+
+ def map_tool_choice_values(
+ self, model: str, tool_choice: Union[str, dict], drop_params: bool
+ ) -> Optional[ToolChoiceValuesBlock]:
+ if tool_choice == "none":
+ if litellm.drop_params is True or drop_params is True:
+ return None
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
+ tool_choice
+ ),
+ status_code=400,
+ )
+ elif tool_choice == "required":
+ return ToolChoiceValuesBlock(any={})
+ elif tool_choice == "auto":
+ return ToolChoiceValuesBlock(auto={})
+ elif isinstance(tool_choice, dict):
+ # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+ specific_tool = SpecificToolChoiceBlock(
+ name=tool_choice.get("function", {}).get("name", "")
+ )
+ return ToolChoiceValuesBlock(tool=specific_tool)
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
+ tool_choice
+ ),
+ status_code=400,
+ )
+
+ def get_supported_image_types(self) -> List[str]:
+ return ["png", "jpeg", "gif", "webp"]
+
+ def map_openai_params(
+ self,
+ model: str,
+ non_default_params: dict,
+ optional_params: dict,
+ drop_params: bool,
+ ) -> dict:
+ for param, value in non_default_params.items():
+ if param == "max_tokens":
+ optional_params["maxTokens"] = value
+ if param == "stream":
+ optional_params["stream"] = value
+ if param == "stop":
+ if isinstance(value, str):
+ value = [value]
+ optional_params["stop_sequences"] = value
+ if param == "temperature":
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["topP"] = value
+ if param == "tools":
+ optional_params["tools"] = value
+ if param == "tool_choice":
+ _tool_choice_value = self.map_tool_choice_values(
+ model=model, tool_choice=value, drop_params=drop_params # type: ignore
+ )
+ if _tool_choice_value is not None:
+ optional_params["tool_choice"] = _tool_choice_value
+ return optional_params
+
+
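`AmazonConverseConfig` above follows the same pattern as the other provider configs: it lists which OpenAI params the Converse API supports for a given model family (`tools` for anthropic/mistral/cohere, `tool_choice` only for anthropic/mistral) and maps them onto Converse field names. A small illustration with made-up values:

```
# Illustration of the Converse param mapping above; values are made up.
config = AmazonConverseConfig()

print(config.get_supported_openai_params(model="anthropic.claude-3-sonnet-20240229-v1:0"))
# includes "tools" and "tool_choice" for anthropic models

mapped = config.map_openai_params(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    non_default_params={
        "max_tokens": 256,
        "temperature": 0.2,
        "top_p": 0.9,
        "stop": "###",
        "tool_choice": "auto",
    },
    optional_params={},
    drop_params=False,
)
print(mapped)
# maxTokens=256, temperature=0.2, topP=0.9, stop_sequences=['###'],
# and tool_choice mapped to a Converse ToolChoice block, e.g. {'auto': {}}
```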
+class BedrockConverseLLM(BaseLLM):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def process_response(
+ self,
+ model: str,
+ response: Union[requests.Response, httpx.Response],
+ model_response: ModelResponse,
+ stream: bool,
+ logging_obj: Logging,
+ optional_params: dict,
+ api_key: str,
+ data: Union[dict, str],
+ messages: List,
+ print_verbose,
+ encoding,
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=response.text,
+ additional_args={"complete_input_dict": data},
+ )
+ print_verbose(f"raw model_response: {response.text}")
+
+ ## RESPONSE OBJECT
+ try:
+ completion_response = ConverseResponseBlock(**response.json()) # type: ignore
+ except Exception as e:
+ raise BedrockError(
+ message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+ response.text, str(e)
+ ),
+ status_code=422,
+ )
+
+ """
+ Bedrock Response Object has optional message block
+
+ completion_response["output"].get("message", None)
+
+ A message block looks like this (Example 1):
+ "output": {
+ "message": {
+ "role": "assistant",
+ "content": [
+ {
+ "text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
+ }
+ ]
+ }
+ },
+ (Example 2):
+ "output": {
+ "message": {
+ "role": "assistant",
+ "content": [
+ {
+ "toolUse": {
+ "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
+ "name": "top_song",
+ "input": {
+ "sign": "WZPZ"
+ }
+ }
+ }
+ ]
+ }
+ }
+
+ """
+ message: Optional[MessageBlock] = completion_response["output"]["message"]
+ chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
+ content_str = ""
+ tools: List[ChatCompletionToolCallChunk] = []
+ if message is not None:
+ for content in message["content"]:
+ """
+ - Content is either a tool response or text
+ """
+ if "text" in content:
+ content_str += content["text"]
+ if "toolUse" in content:
+ _function_chunk = ChatCompletionToolCallFunctionChunk(
+ name=content["toolUse"]["name"],
+ arguments=json.dumps(content["toolUse"]["input"]),
+ )
+ _tool_response_chunk = ChatCompletionToolCallChunk(
+ id=content["toolUse"]["toolUseId"],
+ type="function",
+ function=_function_chunk,
+ )
+ tools.append(_tool_response_chunk)
+ chat_completion_message["content"] = content_str
+ chat_completion_message["tool_calls"] = tools
+
+ ## CALCULATING USAGE - bedrock returns usage in the headers
+ input_tokens = completion_response["usage"]["inputTokens"]
+ output_tokens = completion_response["usage"]["outputTokens"]
+ total_tokens = completion_response["usage"]["totalTokens"]
+
+ model_response.choices = [
+ litellm.Choices(
+ finish_reason=map_finish_reason(completion_response["stopReason"]),
+ index=0,
+ message=litellm.Message(**chat_completion_message),
+ )
+ ]
+ model_response["created"] = int(time.time())
+ model_response["model"] = model
+ usage = Usage(
+ prompt_tokens=input_tokens,
+ completion_tokens=output_tokens,
+ total_tokens=total_tokens,
+ )
+ setattr(model_response, "usage", usage)
+
+ return model_response
+
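As the docstring in `process_response` notes, a Converse `message` block can carry plain `text` content, `toolUse` content, or both; the method concatenates text into `content` and turns each `toolUse` into an OpenAI-style tool call. A standalone restatement of that mapping, reusing the IDs from the docstring's second example (the real code builds typed `ChatCompletionToolCallChunk` objects):

```
# Rough restatement of the toolUse -> tool_calls conversion done above,
# using the second example from the docstring.
import json

bedrock_content = {
    "toolUse": {
        "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",  # ID taken from the docstring example
        "name": "top_song",
        "input": {"sign": "WZPZ"},
    }
}

tool_use = bedrock_content["toolUse"]
openai_tool_call = {
    "id": tool_use["toolUseId"],
    "type": "function",
    "function": {
        "name": tool_use["name"],
        "arguments": json.dumps(tool_use["input"]),  # arguments are serialized to a JSON string
    },
}
print(openai_tool_call)
# {'id': 'tooluse_hbTgdi0CSLq_hM4P8csZJA', 'type': 'function',
#  'function': {'name': 'top_song', 'arguments': '{"sign": "WZPZ"}'}}
```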
+ def encode_model_id(self, model_id: str) -> str:
+ """
+ Double encode the model ID to ensure it matches the expected double-encoded format.
+ Args:
+ model_id (str): The model ID to encode.
+ Returns:
+ str: The double-encoded model ID.
+ """
+ return urllib.parse.quote(model_id, safe="")
+
+ def get_credentials(
+ self,
+ aws_access_key_id: Optional[str] = None,
+ aws_secret_access_key: Optional[str] = None,
+ aws_region_name: Optional[str] = None,
+ aws_session_name: Optional[str] = None,
+ aws_profile_name: Optional[str] = None,
+ aws_role_name: Optional[str] = None,
+ aws_web_identity_token: Optional[str] = None,
+ ):
+ """
+ Return a boto3.Credentials object
+ """
+ import boto3
+
+ ## CHECK IS 'os.environ/' passed in
+ params_to_check: List[Optional[str]] = [
+ aws_access_key_id,
+ aws_secret_access_key,
+ aws_region_name,
+ aws_session_name,
+ aws_profile_name,
+ aws_role_name,
+ aws_web_identity_token,
+ ]
+
+ # Iterate over parameters and update if needed
+ for i, param in enumerate(params_to_check):
+ if param and param.startswith("os.environ/"):
+ _v = get_secret(param)
+ if _v is not None and isinstance(_v, str):
+ params_to_check[i] = _v
+ # Assign updated values back to parameters
+ (
+ aws_access_key_id,
+ aws_secret_access_key,
+ aws_region_name,
+ aws_session_name,
+ aws_profile_name,
+ aws_role_name,
+ aws_web_identity_token,
+ ) = params_to_check
+
+ ### CHECK STS ###
+ if (
+ aws_web_identity_token is not None
+ and aws_role_name is not None
+ and aws_session_name is not None
+ ):
+ oidc_token = get_secret(aws_web_identity_token)
+
+ if oidc_token is None:
+ raise BedrockError(
+ message="OIDC token could not be retrieved from secret manager.",
+ status_code=401,
+ )
+
+ sts_client = boto3.client("sts")
+
+ # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
+ sts_response = sts_client.assume_role_with_web_identity(
+ RoleArn=aws_role_name,
+ RoleSessionName=aws_session_name,
+ WebIdentityToken=oidc_token,
+ DurationSeconds=3600,
+ )
+
+ session = boto3.Session(
+ aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
+ aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
+ aws_session_token=sts_response["Credentials"]["SessionToken"],
+ region_name=aws_region_name,
+ )
+
+ return session.get_credentials()
+ elif aws_role_name is not None and aws_session_name is not None:
+ sts_client = boto3.client(
+ "sts",
+ aws_access_key_id=aws_access_key_id, # [OPTIONAL]
+ aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
+ )
+
+ sts_response = sts_client.assume_role(
+ RoleArn=aws_role_name, RoleSessionName=aws_session_name
+ )
+
+ # Extract the credentials from the response and convert to Session Credentials
+ sts_credentials = sts_response["Credentials"]
+ from botocore.credentials import Credentials
+
+ credentials = Credentials(
+ access_key=sts_credentials["AccessKeyId"],
+ secret_key=sts_credentials["SecretAccessKey"],
+ token=sts_credentials["SessionToken"],
+ )
+ return credentials
+ elif aws_profile_name is not None: ### CHECK SESSION ###
+ # uses auth values from AWS profile usually stored in ~/.aws/credentials
+ client = boto3.Session(profile_name=aws_profile_name)
+
+ return client.get_credentials()
+ else:
+ session = boto3.Session(
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=aws_region_name,
+ )
+
+ return session.get_credentials()
+
+ async def async_streaming(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ data: str,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ encoding,
+ logging_obj,
+ stream,
+ optional_params: dict,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+ client: Optional[AsyncHTTPHandler] = None,
+ ) -> CustomStreamWrapper:
+ streaming_response = CustomStreamWrapper(
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ client=client,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
+ model=model,
+ custom_llm_provider="bedrock",
+ logging_obj=logging_obj,
+ )
+ return streaming_response
+
+ async def async_completion(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ data: str,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ encoding,
+ logging_obj,
+ stream,
+ optional_params: dict,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+ client: Optional[AsyncHTTPHandler] = None,
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
- self.client = AsyncHTTPHandler(**_params) # type: ignore
+ client = AsyncHTTPHandler(**_params) # type: ignore
else:
- self.client = client # type: ignore
+ client = client # type: ignore
- response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore
+ try:
+ response = await client.post(api_base, headers=headers, data=data) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise BedrockError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException as e:
+ raise BedrockError(status_code=408, message="Timeout error occurred.")
- if response.status_code != 200:
- raise BedrockError(status_code=response.status_code, message=response.text)
-
- decoder = AWSEventStreamDecoder(model=model)
-
- completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
- streaming_response = CustomStreamWrapper(
- completion_stream=completion_stream,
+ return self.process_response(
model=model,
- custom_llm_provider="bedrock",
+ response=response,
+ model_response=model_response,
+ stream=stream if isinstance(stream, bool) else False,
logging_obj=logging_obj,
+ api_key="",
+ data=data,
+ messages=messages,
+ print_verbose=print_verbose,
+ optional_params=optional_params,
+ encoding=encoding,
)
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ custom_prompt_dict: dict,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ logging_obj,
+ optional_params: dict,
+ acompletion: bool,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ litellm_params=None,
+ logger_fn=None,
+ extra_headers: Optional[dict] = None,
+ client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+ ):
+ try:
+ import boto3
+
+ from botocore.auth import SigV4Auth
+ from botocore.awsrequest import AWSRequest
+ from botocore.credentials import Credentials
+ except ImportError:
+ raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+
+ ## SETUP ##
+ stream = optional_params.pop("stream", None)
+ modelId = optional_params.pop("model_id", None)
+ if modelId is not None:
+ modelId = self.encode_model_id(model_id=modelId)
+ else:
+ modelId = model
+
+ provider = model.split(".")[0]
+
+ ## CREDENTIALS ##
+ # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
+ aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
+ aws_access_key_id = optional_params.pop("aws_access_key_id", None)
+ aws_region_name = optional_params.pop("aws_region_name", None)
+ aws_role_name = optional_params.pop("aws_role_name", None)
+ aws_session_name = optional_params.pop("aws_session_name", None)
+ aws_profile_name = optional_params.pop("aws_profile_name", None)
+ aws_bedrock_runtime_endpoint = optional_params.pop(
+ "aws_bedrock_runtime_endpoint", None
+ ) # https://bedrock-runtime.{region_name}.amazonaws.com
+ aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
+
+ ### SET REGION NAME ###
+ if aws_region_name is None:
+ # check env #
+ litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+
+ if litellm_aws_region_name is not None and isinstance(
+ litellm_aws_region_name, str
+ ):
+ aws_region_name = litellm_aws_region_name
+
+ standard_aws_region_name = get_secret("AWS_REGION", None)
+ if standard_aws_region_name is not None and isinstance(
+ standard_aws_region_name, str
+ ):
+ aws_region_name = standard_aws_region_name
+
+ if aws_region_name is None:
+ aws_region_name = "us-west-2"
+
+ credentials: Credentials = self.get_credentials(
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ aws_region_name=aws_region_name,
+ aws_session_name=aws_session_name,
+ aws_profile_name=aws_profile_name,
+ aws_role_name=aws_role_name,
+ aws_web_identity_token=aws_web_identity_token,
+ )
+
+ ### SET RUNTIME ENDPOINT ###
+ endpoint_url = ""
+ env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
+ if aws_bedrock_runtime_endpoint is not None and isinstance(
+ aws_bedrock_runtime_endpoint, str
+ ):
+ endpoint_url = aws_bedrock_runtime_endpoint
+ elif env_aws_bedrock_runtime_endpoint and isinstance(
+ env_aws_bedrock_runtime_endpoint, str
+ ):
+ endpoint_url = env_aws_bedrock_runtime_endpoint
+ else:
+ endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
+
+ if (stream is not None and stream is True) and provider != "ai21":
+ endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
+ else:
+ endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
+
+ sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
+
+ # Separate system prompt from rest of message
+ system_prompt_indices = []
+ system_content_blocks: List[SystemContentBlock] = []
+ for idx, message in enumerate(messages):
+ if message["role"] == "system":
+ _system_content_block = SystemContentBlock(text=message["content"])
+ system_content_blocks.append(_system_content_block)
+ system_prompt_indices.append(idx)
+ if len(system_prompt_indices) > 0:
+ for idx in reversed(system_prompt_indices):
+ messages.pop(idx)
+
+ inference_params = copy.deepcopy(optional_params)
+ additional_request_keys = []
+ additional_request_params = {}
+ supported_converse_params = AmazonConverseConfig.__annotations__.keys()
+ supported_tool_call_params = ["tools", "tool_choice"]
+ ## TRANSFORMATION ##
+ # send all model-specific params in 'additional_request_params'
+ for k, v in inference_params.items():
+ if (
+ k not in supported_converse_params
+ and k not in supported_tool_call_params
+ ):
+ additional_request_params[k] = v
+ additional_request_keys.append(k)
+ for key in additional_request_keys:
+ inference_params.pop(key, None)
+
+ bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
+ messages=messages
+ )
+ bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
+ inference_params.pop("tools", [])
+ )
+ bedrock_tool_config: Optional[ToolConfigBlock] = None
+ if len(bedrock_tools) > 0:
+ tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
+ "tool_choice", None
+ )
+ bedrock_tool_config = ToolConfigBlock(
+ tools=bedrock_tools,
+ )
+ if tool_choice_values is not None:
+ bedrock_tool_config["toolChoice"] = tool_choice_values
+
+ _data: RequestObject = {
+ "messages": bedrock_messages,
+ "additionalModelRequestFields": additional_request_params,
+ "system": system_content_blocks,
+ "inferenceConfig": InferenceConfig(**inference_params),
+ }
+ if bedrock_tool_config is not None:
+ _data["toolConfig"] = bedrock_tool_config
+ data = json.dumps(_data)
+ ## COMPLETION CALL
+
+ headers = {"Content-Type": "application/json"}
+ if extra_headers is not None:
+ headers = {"Content-Type": "application/json", **extra_headers}
+ request = AWSRequest(
+ method="POST", url=endpoint_url, data=data, headers=headers
+ )
+ sigv4.add_auth(request)
+ prepped = request.prepare()
+
## LOGGING
- logging_obj.post_call(
+ logging_obj.pre_call(
input=messages,
api_key="",
- original_response=streaming_response,
- additional_args={"complete_input_dict": data},
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": prepped.url,
+ "headers": prepped.headers,
+ },
)
- return streaming_response
+ ### ROUTING (ASYNC, STREAMING, SYNC)
+ if acompletion:
+ if isinstance(client, HTTPHandler):
+ client = None
+ if stream is True and provider != "ai21":
+ return self.async_streaming(
+ model=model,
+ messages=messages,
+ data=data,
+ api_base=prepped.url,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=True,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ headers=prepped.headers,
+ timeout=timeout,
+ client=client,
+ ) # type: ignore
+ ### ASYNC COMPLETION
+ return self.async_completion(
+ model=model,
+ messages=messages,
+ data=data,
+ api_base=prepped.url,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=stream, # type: ignore
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ headers=prepped.headers,
+ timeout=timeout,
+ client=client,
+ ) # type: ignore
- def embedding(self, *args, **kwargs):
- return super().embedding(*args, **kwargs)
+ if (stream is not None and stream is True) and provider != "ai21":
+
+ streaming_response = CustomStreamWrapper(
+ completion_stream=None,
+ make_call=partial(
+ make_sync_call,
+ client=None,
+ api_base=prepped.url,
+ headers=prepped.headers, # type: ignore
+ data=data,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
+ model=model,
+ custom_llm_provider="bedrock",
+ logging_obj=logging_obj,
+ )
+
+ return streaming_response
+ ### COMPLETION
+
+ if client is None or isinstance(client, AsyncHTTPHandler):
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ timeout = httpx.Timeout(timeout)
+ _params["timeout"] = timeout
+ client = HTTPHandler(**_params) # type: ignore
+ else:
+ client = client
+ try:
+ response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+            raise BedrockError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise BedrockError(status_code=408, message="Timeout error occurred.")
+
+ return self.process_response(
+ model=model,
+ response=response,
+ model_response=model_response,
+ stream=stream,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ api_key="",
+ data=data,
+ messages=messages,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ )
def get_response_stream_shape():
@@ -1024,6 +1859,61 @@ class AWSEventStreamDecoder:
self.model = model
self.parser = EventStreamJSONParser()
+ def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+ try:
+ text = ""
+ tool_use: Optional[ChatCompletionToolCallChunk] = None
+ is_finished = False
+ finish_reason = ""
+ usage: Optional[ConverseTokenUsageBlock] = None
+
+ index = int(chunk_data.get("contentBlockIndex", 0))
+ if "start" in chunk_data:
+ start_obj = ContentBlockStartEvent(**chunk_data["start"])
+ if (
+ start_obj is not None
+ and "toolUse" in start_obj
+ and start_obj["toolUse"] is not None
+ ):
+ tool_use = {
+ "id": start_obj["toolUse"]["toolUseId"],
+ "type": "function",
+ "function": {
+ "name": start_obj["toolUse"]["name"],
+ "arguments": "",
+ },
+ }
+ elif "delta" in chunk_data:
+ delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
+ if "text" in delta_obj:
+ text = delta_obj["text"]
+ elif "toolUse" in delta_obj:
+ tool_use = {
+ "id": None,
+ "type": "function",
+ "function": {
+ "name": None,
+ "arguments": delta_obj["toolUse"]["input"],
+ },
+ }
+ elif "stopReason" in chunk_data:
+ finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
+ is_finished = True
+ elif "usage" in chunk_data:
+ usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
+
+ response = GenericStreamingChunk(
+ text=text,
+ tool_use=tool_use,
+ is_finished=is_finished,
+ finish_reason=finish_reason,
+ usage=usage,
+ index=index,
+ )
+ return response
+ except Exception as e:
+ raise Exception("Received streaming error - {}".format(str(e)))
+
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
is_finished = False
@@ -1031,24 +1921,17 @@ class AWSEventStreamDecoder:
if "outputText" in chunk_data:
text = chunk_data["outputText"]
# ai21 mapping
- if "ai21" in self.model: # fake ai21 streaming
+ elif "ai21" in self.model: # fake ai21 streaming
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
is_finished = True
finish_reason = "stop"
######## bedrock.anthropic mappings ###############
- elif "completion" in chunk_data: # not claude-3
- text = chunk_data["completion"] # bedrock.anthropic
- stop_reason = chunk_data.get("stop_reason", None)
- if stop_reason != None:
- is_finished = True
- finish_reason = stop_reason
- elif "delta" in chunk_data:
- if chunk_data["delta"].get("text", None) is not None:
- text = chunk_data["delta"]["text"]
- stop_reason = chunk_data["delta"].get("stop_reason", None)
- if stop_reason != None:
- is_finished = True
- finish_reason = stop_reason
+ elif (
+ "contentBlockIndex" in chunk_data
+ or "stopReason" in chunk_data
+ or "metrics" in chunk_data
+ ):
+ return self.converse_chunk_parser(chunk_data=chunk_data)
######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data:
if (
@@ -1057,7 +1940,7 @@ class AWSEventStreamDecoder:
):
text = chunk_data["outputs"][0]["text"]
stop_reason = chunk_data.get("stop_reason", None)
- if stop_reason != None:
+ if stop_reason is not None:
is_finished = True
finish_reason = stop_reason
######## bedrock.cohere mappings ###############
@@ -1075,11 +1958,12 @@ class AWSEventStreamDecoder:
is_finished = True
finish_reason = chunk_data["completionReason"]
return GenericStreamingChunk(
- **{
- "text": text,
- "is_finished": is_finished,
- "finish_reason": finish_reason,
- }
+ text=text,
+ is_finished=is_finished,
+ finish_reason=finish_reason,
+ usage=None,
+ index=0,
+ tool_use=None,
)
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
@@ -1116,9 +2000,14 @@ class AWSEventStreamDecoder:
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
+ if "chunk" in parsed_response:
+ chunk = parsed_response.get("chunk")
+ if not chunk:
+ return None
+ return chunk.get("bytes").decode() # type: ignore[no-any-return]
+ else:
+ chunk = response_dict.get("body")
+ if not chunk:
+ return None
- chunk = parsed_response.get("chunk")
- if not chunk:
- return None
-
- return chunk.get("bytes").decode() # type: ignore[no-any-return]
+ return chunk.decode() # type: ignore[no-any-return]
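A hypothetical usage sketch of the Converse-based Bedrock route above, driven through litellm's standard completion() interface; the model id, region, and tool definition are illustrative and not taken from this diff.

import litellm

# OpenAI-style tool schema; the handler above translates it into a Bedrock
# `toolConfig` block via _bedrock_tools_pt().
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # illustrative tool name
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # example model id
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    aws_region_name="us-west-2",  # optional; falls back to AWS_REGION_NAME / AWS_REGION
)
print(response.choices[0].message.tool_calls)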
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 4df25944b..5ec9c79bb 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -1,4 +1,5 @@
-import httpx, asyncio
+import litellm
+import httpx, asyncio, traceback, os
from typing import Optional, Union, Mapping, Any
# https://www.python-httpx.org/advanced/timeouts
@@ -11,6 +12,30 @@ class AsyncHTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
concurrent_limit=1000,
):
+ async_proxy_mounts = None
+ # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
+ http_proxy = os.getenv("HTTP_PROXY", None)
+ https_proxy = os.getenv("HTTPS_PROXY", None)
+ no_proxy = os.getenv("NO_PROXY", None)
+ ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
+ cert = os.getenv(
+ "SSL_CERTIFICATE", litellm.ssl_certificate
+ ) # /path/to/client.pem
+
+ if http_proxy is not None and https_proxy is not None:
+ async_proxy_mounts = {
+ "http://": httpx.AsyncHTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
+ "https://": httpx.AsyncHTTPTransport(
+ proxy=httpx.Proxy(url=https_proxy)
+ ),
+ }
+ # assume no_proxy is a list of comma separated urls
+ if no_proxy is not None and isinstance(no_proxy, str):
+ no_proxy_urls = no_proxy.split(",")
+
+ for url in no_proxy_urls: # set no-proxy support for specific urls
+ async_proxy_mounts[url] = None # type: ignore
+
if timeout is None:
timeout = _DEFAULT_TIMEOUT
# Create a client with a connection pool
@@ -20,6 +45,9 @@ class AsyncHTTPHandler:
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
+ verify=ssl_verify,
+ mounts=async_proxy_mounts,
+ cert=cert,
)
async def close(self):
@@ -43,15 +71,22 @@ class AsyncHTTPHandler:
self,
url: str,
data: Optional[Union[dict, str]] = None, # type: ignore
+ json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
- req = self.client.build_request(
- "POST", url, data=data, params=params, headers=headers # type: ignore
- )
- response = await self.client.send(req, stream=stream)
- return response
+ try:
+ req = self.client.build_request(
+ "POST", url, data=data, json=json, params=params, headers=headers # type: ignore
+ )
+ response = await self.client.send(req, stream=stream)
+ response.raise_for_status()
+ return response
+ except httpx.HTTPStatusError as e:
+ raise e
+ except Exception as e:
+ raise e
def __del__(self) -> None:
try:
@@ -70,6 +105,28 @@ class HTTPHandler:
if timeout is None:
timeout = _DEFAULT_TIMEOUT
+ # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
+ http_proxy = os.getenv("HTTP_PROXY", None)
+ https_proxy = os.getenv("HTTPS_PROXY", None)
+ no_proxy = os.getenv("NO_PROXY", None)
+ ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
+ cert = os.getenv(
+ "SSL_CERTIFICATE", litellm.ssl_certificate
+ ) # /path/to/client.pem
+
+ sync_proxy_mounts = None
+ if http_proxy is not None and https_proxy is not None:
+ sync_proxy_mounts = {
+ "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
+ "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
+ }
+ # assume no_proxy is a list of comma separated urls
+ if no_proxy is not None and isinstance(no_proxy, str):
+ no_proxy_urls = no_proxy.split(",")
+
+ for url in no_proxy_urls: # set no-proxy support for specific urls
+ sync_proxy_mounts[url] = None # type: ignore
+
if client is None:
# Create a client with a connection pool
self.client = httpx.Client(
@@ -78,6 +135,9 @@ class HTTPHandler:
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
+ verify=ssl_verify,
+ mounts=sync_proxy_mounts,
+ cert=cert,
)
else:
self.client = client
@@ -96,12 +156,13 @@ class HTTPHandler:
self,
url: str,
data: Optional[Union[dict, str]] = None,
+ json: Optional[Union[dict, str]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
- "POST", url, data=data, params=params, headers=headers # type: ignore
+ "POST", url, data=data, json=json, params=params, headers=headers # type: ignore
)
response = self.client.send(req, stream=stream)
return response
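A minimal standalone sketch of the proxy-mount logic added to the handlers above, assuming an httpx version that supports the proxy= argument on HTTPTransport; this is not the litellm API itself.

import os
import httpx

http_proxy = os.getenv("HTTP_PROXY")
https_proxy = os.getenv("HTTPS_PROXY")
no_proxy = os.getenv("NO_PROXY")  # comma-separated hosts that should bypass the proxy

mounts = None
if http_proxy is not None and https_proxy is not None:
    mounts = {
        "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
        "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
    }
    if no_proxy is not None:
        for url in no_proxy.split(","):
            mounts[url] = None  # a None mount disables proxying for that pattern

client = httpx.Client(mounts=mounts)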
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
index 7b2013710..4fe475259 100644
--- a/litellm/llms/databricks.py
+++ b/litellm/llms/databricks.py
@@ -1,5 +1,6 @@
# What is this?
## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+from functools import partial
import os, types
import json
from enum import Enum
@@ -123,7 +124,7 @@ class DatabricksConfig:
original_chunk = None # this is used for function/tool calling
chunk_data = chunk_data.replace("data:", "")
chunk_data = chunk_data.strip()
- if len(chunk_data) == 0:
+ if len(chunk_data) == 0 or chunk_data == "[DONE]":
return {
"text": "",
"is_finished": is_finished,
@@ -221,6 +222,32 @@ class DatabricksEmbeddingConfig:
return optional_params
+async def make_call(
+ client: AsyncHTTPHandler,
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200:
+ raise DatabricksError(status_code=response.status_code, message=response.text)
+
+ completion_stream = response.aiter_lines()
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=completion_stream, # Pass the completion stream for logging
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
class DatabricksChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
@@ -354,29 +381,21 @@ class DatabricksChatCompletion(BaseLLM):
litellm_params=None,
logger_fn=None,
headers={},
- ):
- self.async_handler = AsyncHTTPHandler(
- timeout=httpx.Timeout(timeout=600.0, connect=5.0)
- )
+ client: Optional[AsyncHTTPHandler] = None,
+ ) -> CustomStreamWrapper:
+
data["stream"] = True
- try:
- response = await self.async_handler.post(
- api_base, headers=headers, data=json.dumps(data), stream=True
- )
- response.raise_for_status()
-
- completion_stream = response.aiter_lines()
- except httpx.HTTPStatusError as e:
- raise DatabricksError(
- status_code=e.response.status_code, message=response.text
- )
- except httpx.TimeoutException as e:
- raise DatabricksError(status_code=408, message="Timeout error occurred.")
- except Exception as e:
- raise DatabricksError(status_code=500, message=str(e))
-
streamwrapper = CustomStreamWrapper(
- completion_stream=completion_stream,
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ api_base=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
model=model,
custom_llm_provider="databricks",
logging_obj=logging_obj,
@@ -475,6 +494,8 @@ class DatabricksChatCompletion(BaseLLM):
},
)
if acompletion == True:
+ if client is not None and isinstance(client, HTTPHandler):
+ client = None
if (
stream is not None and stream == True
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -496,6 +517,7 @@ class DatabricksChatCompletion(BaseLLM):
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
+ client=client,
)
else:
return self.acompletion_function(
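In the Databricks streaming path above, the HTTP request is bound with functools.partial and only executed when CustomStreamWrapper is first iterated, so errors surface at iteration time and the stream can be re-created on retry. A minimal sketch of that deferred pattern (not litellm's actual wrapper, just the idea):

import asyncio
from functools import partial


async def fake_post(api_base: str, data: str):
    # stand-in for client.post(..., stream=True) returning an async line iterator
    for line in ('data: {"token": "hel"}', 'data: {"token": "lo"}'):
        await asyncio.sleep(0)
        yield line


class LazyStreamWrapper:
    # mimics CustomStreamWrapper(completion_stream=None, make_call=partial(...))
    def __init__(self, make_call):
        self._make_call = make_call  # no network call happens here
        self._stream = None

    async def __aiter__(self):
        if self._stream is None:
            self._stream = self._make_call()  # request is made on first iteration
        async for line in self._stream:
            yield line


async def main():
    wrapper = LazyStreamWrapper(
        make_call=partial(fake_post, api_base="https://example.invalid", data="{}")
    )
    async for line in wrapper:
        print(line)


asyncio.run(main())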
diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py
index a55b39aef..cfdf39eca 100644
--- a/litellm/llms/gemini.py
+++ b/litellm/llms/gemini.py
@@ -1,13 +1,14 @@
-import os, types, traceback, copy, asyncio
-import json
-from enum import Enum
+import types
+import traceback
+import copy
import time
from typing import Callable, Optional
-from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
+from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
-import sys, httpx
+import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version
+from litellm import verbose_logger
class GeminiError(Exception):
@@ -264,7 +265,8 @@ def completion(
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
- traceback.print_exc()
+ verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
+ verbose_logger.debug(traceback.format_exc())
raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code
)
@@ -356,7 +358,8 @@ async def async_completion(
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
- traceback.print_exc()
+ verbose_logger.error("LiteLLM.gemini.py: Exception occured - {}".format(str(e)))
+ verbose_logger.debug(traceback.format_exc())
raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code
)
diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index 9c9b5e898..e7dd1d5f5 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -2,10 +2,12 @@ from itertools import chain
import requests, types, time # type: ignore
import json, uuid
import traceback
-from typing import Optional
+from typing import Optional, List
import litellm
+from litellm.types.utils import ProviderField
import httpx, aiohttp, asyncio # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm import verbose_logger
class OllamaError(Exception):
@@ -45,6 +47,8 @@ class OllamaConfig:
- `temperature` (float): The temperature of the model. Increasing the temperature will make the model answer more creatively. Default: 0.8. Example usage: temperature 0.7
+ - `seed` (int): Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. Example usage: seed 42
+
- `stop` (string[]): Sets the stop sequences to use. Example usage: stop "AI assistant:"
- `tfs_z` (float): Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. Default: 1. Example usage: tfs_z 1
@@ -69,6 +73,7 @@ class OllamaConfig:
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
+ seed: Optional[int] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
@@ -90,6 +95,7 @@ class OllamaConfig:
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
+ seed: Optional[int] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
@@ -120,6 +126,59 @@ class OllamaConfig:
)
and v is not None
}
+
+ def get_required_params(self) -> List[ProviderField]:
+ """For a given provider, return it's required fields with a description"""
+ return [
+ ProviderField(
+ field_name="base_url",
+ field_type="string",
+ field_description="Your Ollama API Base",
+ field_value="http://10.10.11.249:11434",
+ )
+ ]
+
+ def get_supported_openai_params(
+ self,
+ ):
+ return [
+ "max_tokens",
+ "stream",
+ "top_p",
+ "temperature",
+ "seed",
+ "frequency_penalty",
+ "stop",
+ "response_format",
+ ]
+
+
+# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
+# and convert to jpeg if necessary.
+def _convert_image(image):
+ import base64, io
+
+ try:
+ from PIL import Image
+ except:
+ raise Exception(
+ "ollama image conversion failed please run `pip install Pillow`"
+ )
+
+ orig = image
+ if image.startswith("data:"):
+ image = image.split(",")[-1]
+ try:
+ image_data = Image.open(io.BytesIO(base64.b64decode(image)))
+ if image_data.format in ["JPEG", "PNG"]:
+ return image
+ except:
+ return orig
+ jpeg_image = io.BytesIO()
+ image_data.convert("RGB").save(jpeg_image, "JPEG")
+ jpeg_image.seek(0)
+ return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")
# ollama implementation
@@ -158,7 +217,7 @@ def get_ollama_response(
if format is not None:
data["format"] = format
if images is not None:
- data["images"] = images
+ data["images"] = [_convert_image(image) for image in images]
## LOGGING
logging_obj.pre_call(
@@ -349,7 +408,13 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
- traceback.print_exc()
+ verbose_logger.error(
+ "LiteLLM.ollama.py::ollama_async_streaming(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_logger.debug(traceback.format_exc())
+
raise e
@@ -413,7 +478,12 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
)
return model_response
except Exception as e:
- traceback.print_exc()
+ verbose_logger.error(
+ "LiteLLM.ollama.py::ollama_acompletion(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_logger.debug(traceback.format_exc())
raise e
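An illustrative check of the _convert_image() helper added above, assuming Pillow is installed and that the function is importable from litellm.llms.ollama as laid out in this diff; the 1x1 PNG is a synthetic test payload.

import base64
import io

from PIL import Image
from litellm.llms.ollama import _convert_image

buf = io.BytesIO()
Image.new("RGB", (1, 1), "red").save(buf, format="PNG")
png_b64 = base64.b64encode(buf.getvalue()).decode()

# the leading data-URI prefix is stripped; since the payload is already PNG,
# the bare base64 string is returned unchanged
assert _convert_image(f"data:image/png;base64,{png_b64}") == png_b64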
diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index d1ff4953f..a7439bbcc 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -1,11 +1,15 @@
from itertools import chain
-import requests, types, time
-import json, uuid
+import requests
+import types
+import time
+import json
+import uuid
import traceback
from typing import Optional
+from litellm import verbose_logger
import litellm
-import httpx, aiohttp, asyncio
-from .prompt_templates.factory import prompt_factory, custom_prompt
+import httpx
+import aiohttp
class OllamaError(Exception):
@@ -45,6 +49,8 @@ class OllamaChatConfig:
- `temperature` (float): The temperature of the model. Increasing the temperature will make the model answer more creatively. Default: 0.8. Example usage: temperature 0.7
+ - `seed` (int): Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. Example usage: seed 42
+
- `stop` (string[]): Sets the stop sequences to use. Example usage: stop "AI assistant:"
- `tfs_z` (float): Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. Default: 1. Example usage: tfs_z 1
@@ -69,6 +75,7 @@ class OllamaChatConfig:
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
+ seed: Optional[int] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
@@ -90,6 +97,7 @@ class OllamaChatConfig:
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
+ seed: Optional[int] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
@@ -130,6 +138,7 @@ class OllamaChatConfig:
"stream",
"top_p",
"temperature",
+ "seed",
"frequency_penalty",
"stop",
"tools",
@@ -146,6 +155,8 @@ class OllamaChatConfig:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
+ if param == "seed":
+ optional_params["seed"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "frequency_penalty":
@@ -292,7 +303,10 @@ def get_ollama_response(
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
- "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
+ "function": {
+ "name": function_call["name"],
+ "arguments": json.dumps(function_call["arguments"]),
+ },
"type": "function",
}
],
@@ -300,7 +314,9 @@ def get_ollama_response(
model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else:
- model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
+ model_response["choices"][0]["message"]["content"] = response_json["message"][
+ "content"
+ ]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
@@ -354,7 +370,10 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
- "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
+ "function": {
+ "name": function_call["name"],
+ "arguments": json.dumps(function_call["arguments"]),
+ },
"type": "function",
}
],
@@ -403,9 +422,10 @@ async def ollama_async_streaming(
first_chunk_content = first_chunk.choices[0].delta.content or ""
response_content = first_chunk_content + "".join(
[
- chunk.choices[0].delta.content
- async for chunk in streamwrapper
- if chunk.choices[0].delta.content]
+ chunk.choices[0].delta.content
+ async for chunk in streamwrapper
+ if chunk.choices[0].delta.content
+ ]
)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
@@ -413,7 +433,10 @@ async def ollama_async_streaming(
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
- "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
+ "function": {
+ "name": function_call["name"],
+ "arguments": json.dumps(function_call["arguments"]),
+ },
"type": "function",
}
],
@@ -426,7 +449,8 @@ async def ollama_async_streaming(
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
- traceback.print_exc()
+ verbose_logger.error("LiteLLM.gemini(): Exception occured - {}".format(str(e)))
+ verbose_logger.debug(traceback.format_exc())
async def ollama_acompletion(
@@ -476,7 +500,10 @@ async def ollama_acompletion(
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
- "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
+ "function": {
+ "name": function_call["name"],
+ "arguments": json.dumps(function_call["arguments"]),
+ },
"type": "function",
}
],
@@ -484,7 +511,9 @@ async def ollama_acompletion(
model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else:
- model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
+ model_response["choices"][0]["message"]["content"] = response_json[
+ "message"
+ ]["content"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama_chat/" + data["model"]
@@ -502,5 +531,9 @@ async def ollama_acompletion(
)
return model_response
except Exception as e:
- traceback.print_exc()
+ verbose_logger.error(
+ "LiteLLM.ollama_acompletion(): Exception occured - {}".format(str(e))
+ )
+ verbose_logger.debug(traceback.format_exc())
+
raise e
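A hypothetical end-to-end use of the `seed` option mapped above; the model name and api_base are placeholders for a locally running Ollama server.

import litellm

common = dict(
    model="ollama_chat/llama3",  # placeholder model served by Ollama
    messages=[{"role": "user", "content": "Pick a random animal."}],
    api_base="http://localhost:11434",
    seed=42,
)

resp_a = litellm.completion(**common)
resp_b = litellm.completion(**common)

# with the same seed, Ollama should produce the same text for the same prompt
print(resp_a.choices[0].message.content == resp_b.choices[0].message.content)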
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 6197ec922..dec86d35d 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -6,7 +6,8 @@ from typing import (
Literal,
Iterable,
)
-from typing_extensions import override
+import hashlib
+from typing_extensions import override, overload
from pydantic import BaseModel
import types, time, json, traceback
import httpx
@@ -21,11 +22,12 @@ from litellm.utils import (
TranscriptionResponse,
TextCompletionResponse,
)
-from typing import Callable, Optional
+from typing import Callable, Optional, Coroutine
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
from ..types.llms.openai import *
+import openai
class OpenAIError(Exception):
@@ -224,6 +226,7 @@ class DeepInfraConfig:
def get_supported_openai_params(self):
return [
+ "stream",
"frequency_penalty",
"function_call",
"functions",
@@ -348,7 +351,6 @@ class OpenAIConfig:
"top_p",
"tools",
"tool_choice",
- "user",
"function_call",
"functions",
"max_retries",
@@ -361,6 +363,12 @@ class OpenAIConfig:
): # gpt-4 does not support 'response_format'
model_specific_params.append("response_format")
+ if (
+ model in litellm.open_ai_chat_completion_models
+ ) or model in litellm.open_ai_text_completion_models:
+ model_specific_params.append(
+ "user"
+ ) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
return base_params + model_specific_params
def map_openai_params(
@@ -497,6 +505,64 @@ class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
+ def _get_openai_client(
+ self,
+ is_async: bool,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+ max_retries: Optional[int] = None,
+ organization: Optional[str] = None,
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ):
+ args = locals()
+ if client is None:
+ if not isinstance(max_retries, int):
+ raise OpenAIError(
+ status_code=422,
+ message="max retries must be an int. Passed in value: {}".format(
+ max_retries
+ ),
+ )
+ # Creating a new OpenAI Client
+ # check in memory cache before creating a new one
+ # Convert the API key to bytes
+ hashed_api_key = None
+ if api_key is not None:
+ hash_object = hashlib.sha256(api_key.encode())
+ # Hexadecimal representation of the hash
+ hashed_api_key = hash_object.hexdigest()
+
+ _cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"
+
+ if _cache_key in litellm.in_memory_llm_clients_cache:
+ return litellm.in_memory_llm_clients_cache[_cache_key]
+ if is_async:
+ _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ http_client=litellm.aclient_session,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ )
+ else:
+ _new_client = OpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ http_client=litellm.client_session,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ )
+
+ ## SAVE CACHE KEY
+ litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
+ return _new_client
+
+ else:
+ return client
+
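A standalone sketch of the client-caching scheme used by _get_openai_client() above: the API key is SHA-256 hashed and combined with the remaining constructor arguments into an in-memory cache key, so repeated calls reuse one client instead of rebuilding connection pools. The builder callable below is a stand-in for OpenAI/AsyncOpenAI construction.

import hashlib

in_memory_llm_clients_cache: dict = {}


def cached_client(api_key, api_base, timeout, max_retries, organization, is_async, builder):
    """builder is a zero-arg callable that constructs the client (e.g. lambda: OpenAI(...))."""
    hashed_api_key = (
        hashlib.sha256(api_key.encode()).hexdigest() if api_key is not None else None
    )
    cache_key = (
        f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},"
        f"max_retries={max_retries},organization={organization},is_async={is_async}"
    )
    if cache_key not in in_memory_llm_clients_cache:
        in_memory_llm_clients_cache[cache_key] = builder()
    return in_memory_llm_clients_cache[cache_key]


# usage: two lookups with the same arguments return the same object
a = cached_client("sk-test", None, 600.0, 2, None, False, builder=lambda: object())
b = cached_client("sk-test", None, 600.0, 2, None, False, builder=lambda: object())
assert a is b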
def completion(
self,
model_response: ModelResponse,
@@ -603,17 +669,16 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(
status_code=422, message="max retries must be an int"
)
- if client is None:
- openai_client = OpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.client_session,
- timeout=timeout,
- max_retries=max_retries,
- organization=organization,
- )
- else:
- openai_client = client
+
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
## LOGGING
logging_obj.pre_call(
@@ -693,17 +758,15 @@ class OpenAIChatCompletion(BaseLLM):
):
response = None
try:
- if client is None:
- openai_aclient = AsyncOpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.aclient_session,
- timeout=timeout,
- max_retries=max_retries,
- organization=organization,
- )
- else:
- openai_aclient = client
+ openai_aclient = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
## LOGGING
logging_obj.pre_call(
@@ -747,17 +810,15 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None,
headers=None,
):
- if client is None:
- openai_client = OpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.client_session,
- timeout=timeout,
- max_retries=max_retries,
- organization=organization,
- )
- else:
- openai_client = client
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
## LOGGING
logging_obj.pre_call(
input=data["messages"],
@@ -794,17 +855,15 @@ class OpenAIChatCompletion(BaseLLM):
):
response = None
try:
- if client is None:
- openai_aclient = AsyncOpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.aclient_session,
- timeout=timeout,
- max_retries=max_retries,
- organization=organization,
- )
- else:
- openai_aclient = client
+ openai_aclient = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
## LOGGING
logging_obj.pre_call(
input=data["messages"],
@@ -858,16 +917,14 @@ class OpenAIChatCompletion(BaseLLM):
):
response = None
try:
- if client is None:
- openai_aclient = AsyncOpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.aclient_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_aclient = client
+ openai_aclient = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump()
## LOGGING
@@ -915,19 +972,18 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data, "api_base": api_base},
)
- if aembedding == True:
+ if aembedding is True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
- if client is None:
- openai_client = OpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.client_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_client = client
+
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
## COMPLETION CALL
response = openai_client.embeddings.create(**data, timeout=timeout) # type: ignore
@@ -963,16 +1019,16 @@ class OpenAIChatCompletion(BaseLLM):
):
response = None
try:
- if client is None:
- openai_aclient = AsyncOpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.aclient_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_aclient = client
+
+ openai_aclient = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump()
## LOGGING
@@ -1017,16 +1073,14 @@ class OpenAIChatCompletion(BaseLLM):
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
- if client is None:
- openai_client = OpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.client_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_client = client
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
## LOGGING
logging_obj.pre_call(
@@ -1084,14 +1138,14 @@ class OpenAIChatCompletion(BaseLLM):
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
- api_key: Optional[str] = None,
- api_base: Optional[str] = None,
+ api_key: Optional[str],
+ api_base: Optional[str],
client=None,
logging_obj=None,
atranscription: bool = False,
):
data = {"model": model, "file": audio_file, **optional_params}
- if atranscription == True:
+ if atranscription is True:
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
@@ -1103,16 +1157,14 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=max_retries,
logging_obj=logging_obj,
)
- if client is None:
- openai_client = OpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.client_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_client = client
+
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+            client=client,
+        )
response = openai_client.audio.transcriptions.create(
**data, timeout=timeout # type: ignore
)
@@ -1141,18 +1193,16 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None,
logging_obj=None,
):
- response = None
try:
- if client is None:
- openai_aclient = AsyncOpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.aclient_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_aclient = client
+ openai_aclient = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
response = await openai_aclient.audio.transcriptions.create(
**data, timeout=timeout
) # type: ignore
@@ -1175,6 +1225,87 @@ class OpenAIChatCompletion(BaseLLM):
)
raise e
+ def audio_speech(
+ self,
+ model: str,
+ input: str,
+ voice: str,
+ optional_params: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ organization: Optional[str],
+ project: Optional[str],
+ max_retries: int,
+ timeout: Union[float, httpx.Timeout],
+ aspeech: Optional[bool] = None,
+ client=None,
+ ) -> HttpxBinaryResponseContent:
+
+ if aspeech is not None and aspeech is True:
+ return self.async_audio_speech(
+ model=model,
+ input=input,
+ voice=voice,
+ optional_params=optional_params,
+ api_key=api_key,
+ api_base=api_base,
+ organization=organization,
+ project=project,
+ max_retries=max_retries,
+ timeout=timeout,
+ client=client,
+ ) # type: ignore
+
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = openai_client.audio.speech.create(
+ model=model,
+ voice=voice, # type: ignore
+ input=input,
+ **optional_params,
+ )
+ return response
+
+ async def async_audio_speech(
+ self,
+ model: str,
+ input: str,
+ voice: str,
+ optional_params: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ organization: Optional[str],
+ project: Optional[str],
+ max_retries: int,
+ timeout: Union[float, httpx.Timeout],
+ client=None,
+ ) -> HttpxBinaryResponseContent:
+
+ openai_client = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = await openai_client.audio.speech.create(
+ model=model,
+ voice=voice, # type: ignore
+ input=input,
+ **optional_params,
+ )
+
+ return response
+
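Assuming a top-level litellm.speech() helper that fronts the audio_speech()/async_audio_speech() handlers above (the helper name is an assumption, not shown in this hunk), text-to-speech usage would look roughly like this; the model, voice, and output path are illustrative.

import litellm

response = litellm.speech(
    model="openai/tts-1",
    voice="alloy",
    input="Hello from LiteLLM",
)
# the handlers above return the OpenAI SDK's HttpxBinaryResponseContent
response.stream_to_file("speech.mp3")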
async def ahealth_check(
self,
model: Optional[str],
@@ -1496,6 +1627,322 @@ class OpenAITextCompletion(BaseLLM):
yield transformed_chunk
+class OpenAIFilesAPI(BaseLLM):
+ """
+    OpenAI methods to support files
+ - create_file()
+ - retrieve_file()
+ - list_files()
+ - delete_file()
+ - file_content()
+ - update_file()
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def get_openai_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ _is_async: bool = False,
+ ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
+ received_args = locals()
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client" or k == "_is_async":
+ pass
+ elif k == "api_base" and v is not None:
+ data["base_url"] = v
+ elif v is not None:
+ data[k] = v
+ if _is_async is True:
+ openai_client = AsyncOpenAI(**data)
+ else:
+ openai_client = OpenAI(**data) # type: ignore
+ else:
+ openai_client = client
+
+ return openai_client
+
+ async def acreate_file(
+ self,
+ create_file_data: CreateFileRequest,
+ openai_client: AsyncOpenAI,
+ ) -> FileObject:
+ response = await openai_client.files.create(**create_file_data)
+ return response
+
+ def create_file(
+ self,
+ _is_async: bool,
+ create_file_data: CreateFileRequest,
+ api_base: str,
+ api_key: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.acreate_file( # type: ignore
+ create_file_data=create_file_data, openai_client=openai_client
+ )
+ response = openai_client.files.create(**create_file_data)
+ return response
+
+ async def afile_content(
+ self,
+ file_content_request: FileContentRequest,
+ openai_client: AsyncOpenAI,
+ ) -> HttpxBinaryResponseContent:
+ response = await openai_client.files.content(**file_content_request)
+ return response
+
+ def file_content(
+ self,
+ _is_async: bool,
+ file_content_request: FileContentRequest,
+ api_base: str,
+ api_key: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ) -> Union[
+ HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
+ ]:
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.afile_content( # type: ignore
+ file_content_request=file_content_request,
+ openai_client=openai_client,
+ )
+ response = openai_client.files.content(**file_content_request)
+
+ return response
+
+
+class OpenAIBatchesAPI(BaseLLM):
+ """
+    OpenAI methods to support batches
+ - create_batch()
+ - retrieve_batch()
+ - cancel_batch()
+ - list_batch()
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def get_openai_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ _is_async: bool = False,
+ ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
+ received_args = locals()
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client" or k == "_is_async":
+ pass
+ elif k == "api_base" and v is not None:
+ data["base_url"] = v
+ elif v is not None:
+ data[k] = v
+ if _is_async is True:
+ openai_client = AsyncOpenAI(**data)
+ else:
+ openai_client = OpenAI(**data) # type: ignore
+ else:
+ openai_client = client
+
+ return openai_client
+
+ async def acreate_batch(
+ self,
+ create_batch_data: CreateBatchRequest,
+ openai_client: AsyncOpenAI,
+ ) -> Batch:
+ response = await openai_client.batches.create(**create_batch_data)
+ return response
+
+ def create_batch(
+ self,
+ _is_async: bool,
+ create_batch_data: CreateBatchRequest,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.acreate_batch( # type: ignore
+ create_batch_data=create_batch_data, openai_client=openai_client
+ )
+ response = openai_client.batches.create(**create_batch_data)
+ return response
+
+ async def aretrieve_batch(
+ self,
+ retrieve_batch_data: RetrieveBatchRequest,
+ openai_client: AsyncOpenAI,
+ ) -> Batch:
+ response = await openai_client.batches.retrieve(**retrieve_batch_data)
+ return response
+
+ def retrieve_batch(
+ self,
+ _is_async: bool,
+ retrieve_batch_data: RetrieveBatchRequest,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI] = None,
+ ):
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.aretrieve_batch( # type: ignore
+ retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
+ )
+ response = openai_client.batches.retrieve(**retrieve_batch_data)
+ return response
+
+ def cancel_batch(
+ self,
+ _is_async: bool,
+ cancel_batch_data: CancelBatchRequest,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI] = None,
+ ):
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+ response = openai_client.batches.cancel(**cancel_batch_data)
+ return response
+
+ # def list_batch(
+ # self,
+ # list_batch_data: ListBatchRequest,
+ # api_key: Optional[str],
+ # api_base: Optional[str],
+ # timeout: Union[float, httpx.Timeout],
+ # max_retries: Optional[int],
+ # organization: Optional[str],
+ # client: Optional[OpenAI] = None,
+ # ):
+ # openai_client: OpenAI = self.get_openai_client(
+ # api_key=api_key,
+ # api_base=api_base,
+ # timeout=timeout,
+ # max_retries=max_retries,
+ # organization=organization,
+ # client=client,
+ # )
+ # response = openai_client.batches.list(**list_batch_data)
+ # return response
+
+
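A hedged sketch of driving the new OpenAIBatchesAPI wrapper above directly; the request fields mirror the OpenAI Batches API, and the file id and API key are placeholders.

from litellm.llms.openai import OpenAIBatchesAPI

batches_api = OpenAIBatchesAPI()
batch = batches_api.create_batch(
    _is_async=False,
    create_batch_data={
        "input_file_id": "file-abc123",      # placeholder id from a prior create_file()
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
    api_key="sk-...",                         # placeholder key
    api_base=None,
    timeout=600.0,
    max_retries=2,
    organization=None,
)
print(batch.status)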
class OpenAIAssistantsAPI(BaseLLM):
def __init__(self) -> None:
super().__init__()
@@ -1525,8 +1972,85 @@ class OpenAIAssistantsAPI(BaseLLM):
return openai_client
+ def async_get_openai_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI] = None,
+ ) -> AsyncOpenAI:
+ received_args = locals()
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client":
+ pass
+ elif k == "api_base" and v is not None:
+ data["base_url"] = v
+ elif v is not None:
+ data[k] = v
+ openai_client = AsyncOpenAI(**data) # type: ignore
+ else:
+ openai_client = client
+
+ return openai_client
+
### ASSISTANTS ###
+ async def async_get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ ) -> AsyncCursorPage[Assistant]:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+
+ response = await openai_client.beta.assistants.list()
+
+ return response
+
+ # fmt: off
+
+ @overload
+ def get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ aget_assistants: Literal[True],
+ ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
+ ...
+
+ @overload
+ def get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI],
+ aget_assistants: Optional[Literal[False]],
+ ) -> SyncCursorPage[Assistant]:
+ ...
+
+ # fmt: on
+
def get_assistants(
self,
api_key: Optional[str],
@@ -1534,8 +2058,18 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI],
- ) -> SyncCursorPage[Assistant]:
+ client=None,
+ aget_assistants=None,
+ ):
+ if aget_assistants is not None and aget_assistants == True:
+ return self.async_get_assistants(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1551,18 +2085,95 @@ class OpenAIAssistantsAPI(BaseLLM):
### MESSAGES ###
- def add_message(
+ async def a_add_message(
self,
thread_id: str,
- message_data: MessageData,
+ message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI] = None,
+ client: Optional[AsyncOpenAI] = None,
) -> OpenAIMessage:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+ thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
+ thread_id, **message_data # type: ignore
+ )
+
+ response_obj: Optional[OpenAIMessage] = None
+ if getattr(thread_message, "status", None) is None:
+ thread_message.status = "completed"
+ response_obj = OpenAIMessage(**thread_message.dict())
+ else:
+ response_obj = OpenAIMessage(**thread_message.dict())
+ return response_obj
+
+ # fmt: off
+
+ @overload
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ a_add_message: Literal[True],
+ ) -> Coroutine[None, None, OpenAIMessage]:
+ ...
+
+ @overload
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI],
+ a_add_message: Optional[Literal[False]],
+ ) -> OpenAIMessage:
+ ...
+
+ # fmt: on
+
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client=None,
+ a_add_message: Optional[bool] = None,
+ ):
+ if a_add_message is not None and a_add_message == True:
+ return self.a_add_message(
+ thread_id=thread_id,
+ message_data=message_data,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1584,6 +2195,61 @@ class OpenAIAssistantsAPI(BaseLLM):
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
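The assistants methods all follow the same flag-based dispatch shown here for `add_message()`: the typed `@overload`s only narrow the return type, while the runtime implementation checks the boolean flag and either returns the coroutine from the `a_*` variant or runs the synchronous OpenAI client. A minimal sketch, assuming the class lives at `litellm.llms.openai` and using placeholder ids:

```python
import asyncio
from openai import OpenAI, AsyncOpenAI

from litellm.llms.openai import OpenAIAssistantsAPI  # assumed module path

api = OpenAIAssistantsAPI()
common = dict(api_key=None, api_base=None, timeout=600.0, max_retries=2, organization=None)

# sync path: returns an OpenAIMessage directly
msg = api.add_message(
    thread_id="thread_abc123",  # placeholder id
    message_data={"role": "user", "content": "Hello"},
    client=OpenAI(),
    a_add_message=False,
    **common,
)

# async path: add_message() returns the a_add_message() coroutine
async def main():
    return await api.add_message(
        thread_id="thread_abc123",
        message_data={"role": "user", "content": "Hello"},
        client=AsyncOpenAI(),
        a_add_message=True,
        **common,
    )

# asyncio.run(main())
```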
+ async def async_get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI] = None,
+ ) -> AsyncCursorPage[OpenAIMessage]:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+ return response
+
+ # fmt: off
+
+ @overload
+ def get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ aget_messages: Literal[True],
+ ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
+ ...
+
+ @overload
+ def get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI],
+ aget_messages: Optional[Literal[False]],
+ ) -> SyncCursorPage[OpenAIMessage]:
+ ...
+
+ # fmt: on
+
def get_messages(
self,
thread_id: str,
@@ -1592,8 +2258,19 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI] = None,
- ) -> SyncCursorPage[OpenAIMessage]:
+ client=None,
+ aget_messages=None,
+ ):
+ if aget_messages is not None and aget_messages == True:
+ return self.async_get_messages(
+ thread_id=thread_id,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1609,6 +2286,70 @@ class OpenAIAssistantsAPI(BaseLLM):
### THREADS ###
+ async def async_create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ ) -> Thread:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+
+ data = {}
+ if messages is not None:
+ data["messages"] = messages # type: ignore
+ if metadata is not None:
+ data["metadata"] = metadata # type: ignore
+
+ message_thread = await openai_client.beta.threads.create(**data) # type: ignore
+
+ return Thread(**message_thread.dict())
+
+ # fmt: off
+
+ @overload
+ def create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ client: Optional[AsyncOpenAI],
+ acreate_thread: Literal[True],
+ ) -> Coroutine[None, None, Thread]:
+ ...
+
+ @overload
+ def create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ client: Optional[OpenAI],
+ acreate_thread: Optional[Literal[False]],
+ ) -> Thread:
+ ...
+
+ # fmt: on
+
def create_thread(
self,
metadata: Optional[dict],
@@ -1617,9 +2358,10 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
- ) -> Thread:
+ client=None,
+ acreate_thread=None,
+ ):
"""
Here's an example:
```
@@ -1630,6 +2372,17 @@ class OpenAIAssistantsAPI(BaseLLM):
openai_api.create_thread(messages=[message])
```
"""
+ if acreate_thread is not None and acreate_thread == True:
+ return self.async_create_thread(
+ metadata=metadata,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ messages=messages,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1649,6 +2402,61 @@ class OpenAIAssistantsAPI(BaseLLM):
return Thread(**message_thread.dict())
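`create_thread()` follows the same pattern: with `acreate_thread=True` it returns the `async_create_thread()` coroutine, and passing `client=None` lets `async_get_openai_client()` build a default `AsyncOpenAI` from the environment. A short sketch with placeholder content (module path assumed, as above):

```python
import asyncio

from litellm.llms.openai import OpenAIAssistantsAPI  # assumed module path

async def new_thread():
    api = OpenAIAssistantsAPI()
    return await api.create_thread(
        metadata={"purpose": "demo"},  # placeholder metadata
        api_key=None,
        api_base=None,
        timeout=600.0,
        max_retries=2,
        organization=None,
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        client=None,  # a default AsyncOpenAI() is built from env vars
        acreate_thread=True,
    )

thread = asyncio.run(new_thread())
```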
+ async def async_get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ ) -> Thread:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+ return Thread(**response.dict())
+
+ # fmt: off
+
+ @overload
+ def get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ aget_thread: Literal[True],
+ ) -> Coroutine[None, None, Thread]:
+ ...
+
+ @overload
+ def get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI],
+ aget_thread: Optional[Literal[False]],
+ ) -> Thread:
+ ...
+
+ # fmt: on
+
def get_thread(
self,
thread_id: str,
@@ -1657,8 +2465,19 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI],
- ) -> Thread:
+ client=None,
+ aget_thread=None,
+ ):
+ if aget_thread is not None and aget_thread == True:
+ return self.async_get_thread(
+ thread_id=thread_id,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1677,6 +2496,142 @@ class OpenAIAssistantsAPI(BaseLLM):
### RUNS ###
+ async def arun_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ ) -> Run:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ )
+
+ return response
+
+ def async_run_thread_stream(
+ self,
+ client: AsyncOpenAI,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ tools: Optional[Iterable[AssistantToolParam]],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
+ data = {
+ "thread_id": thread_id,
+ "assistant_id": assistant_id,
+ "additional_instructions": additional_instructions,
+ "instructions": instructions,
+ "metadata": metadata,
+ "model": model,
+ "tools": tools,
+ }
+ if event_handler is not None:
+ data["event_handler"] = event_handler
+ return client.beta.threads.runs.stream(**data) # type: ignore
+
+ def run_thread_stream(
+ self,
+ client: OpenAI,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ tools: Optional[Iterable[AssistantToolParam]],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> AssistantStreamManager[AssistantEventHandler]:
+ data = {
+ "thread_id": thread_id,
+ "assistant_id": assistant_id,
+ "additional_instructions": additional_instructions,
+ "instructions": instructions,
+ "metadata": metadata,
+ "model": model,
+ "tools": tools,
+ }
+ if event_handler is not None:
+ data["event_handler"] = event_handler
+ return client.beta.threads.runs.stream(**data) # type: ignore
+
+ # fmt: off
+
+ @overload
+ def run_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client,
+ arun_thread: Literal[True],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> Coroutine[None, None, Run]:
+ ...
+
+ @overload
+ def run_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client,
+ arun_thread: Optional[Literal[False]],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> Run:
+ ...
+
+ # fmt: on
+
def run_thread(
self,
thread_id: str,
@@ -1692,8 +2647,47 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI],
- ) -> Run:
+ client=None,
+ arun_thread=None,
+ event_handler: Optional[AssistantEventHandler] = None,
+ ):
+ if arun_thread is not None and arun_thread == True:
+ if stream is not None and stream == True:
+ _client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+ return self.async_run_thread_stream(
+ client=_client,
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ event_handler=event_handler,
+ )
+ return self.arun_thread(
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ stream=stream,
+ tools=tools,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1703,6 +2697,19 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client,
)
+ if stream is not None and stream == True:
+ return self.run_thread_stream(
+ client=openai_client,
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ event_handler=event_handler,
+ )
+
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
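With the changes above, `run_thread()` gains a streaming mode: when `stream=True` it returns the `AssistantStreamManager` from `client.beta.threads.runs.stream()` (or the async manager when `arun_thread=True`) instead of the polled `Run` from `create_and_poll()`. A minimal synchronous sketch, with placeholder ids and an assumed module path:

```python
from litellm.llms.openai import OpenAIAssistantsAPI  # assumed module path

api = OpenAIAssistantsAPI()
stream_manager = api.run_thread(
    thread_id="thread_abc123",   # placeholder id
    assistant_id="asst_abc123",  # placeholder id
    additional_instructions=None,
    instructions=None,
    metadata=None,
    model=None,
    stream=True,                 # -> run_thread_stream() path
    tools=None,
    api_key=None,
    api_base=None,
    timeout=600.0,
    max_retries=2,
    organization=None,
    client=None,
    arun_thread=False,
    event_handler=None,          # optionally an AssistantEventHandler subclass
)
with stream_manager as stream:
    stream.until_done()  # drain the stream; pass event_handler= to act on events
```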
diff --git a/litellm/llms/palm.py b/litellm/llms/palm.py
index f15be43db..4d9953e77 100644
--- a/litellm/llms/palm.py
+++ b/litellm/llms/palm.py
@@ -1,11 +1,12 @@
-import os, types, traceback, copy
-import json
-from enum import Enum
+import types
+import traceback
+import copy
import time
from typing import Callable, Optional
-from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
+from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
-import sys, httpx
+import httpx
+from litellm import verbose_logger
class PalmError(Exception):
@@ -165,7 +166,10 @@ def completion(
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
- traceback.print_exc()
+ verbose_logger.error(
+ "litellm.llms.palm.py::completion(): Exception occurred - {}".format(str(e))
+ )
+ verbose_logger.debug(traceback.format_exc())
raise PalmError(
message=traceback.format_exc(), status_code=response.status_code
)
diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py
index 1e7e1d334..66c28acee 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase.py
@@ -1,8 +1,9 @@
# What is this?
## Controller file for Predibase Integration - https://predibase.com/
-
+from functools import partial
import os, types
+import traceback
import json
from enum import Enum
import requests, copy # type: ignore
@@ -51,6 +52,32 @@ class PredibaseError(Exception):
) # Call the base class constructor with the parameters it needs
+async def make_call(
+ client: AsyncHTTPHandler,
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200:
+ raise PredibaseError(status_code=response.status_code, message=response.text)
+
+ completion_stream = response.aiter_lines()
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=completion_stream, # Pass the completion stream for logging
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
class PredibaseConfig:
"""
Reference: https://docs.predibase.com/user-guide/inference/rest_api
@@ -126,11 +153,17 @@ class PredibaseChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
- def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
+ def _validate_environment(
+ self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
+ ) -> dict:
if api_key is None:
raise ValueError(
"Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
)
+ if tenant_id is None:
+ raise ValueError(
+ "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=)`) or in env - `PREDIBASE_TENANT_ID`."
+ )
headers = {
"content-type": "application/json",
"Authorization": "Bearer {}".format(api_key),
@@ -210,12 +243,12 @@ class PredibaseChatCompletion(BaseLLM):
"details" in completion_response
and "tokens" in completion_response["details"]
):
- model_response.choices[0].finish_reason = completion_response[
- "details"
- ]["finish_reason"]
+ model_response.choices[0].finish_reason = map_finish_reason(
+ completion_response["details"]["finish_reason"]
+ )
sum_logprob = 0
for token in completion_response["details"]["tokens"]:
- if token["logprob"] != None:
+ if token["logprob"] is not None:
sum_logprob += token["logprob"]
model_response["choices"][0][
"message"
@@ -233,7 +266,7 @@ class PredibaseChatCompletion(BaseLLM):
):
sum_logprob = 0
for token in item["tokens"]:
- if token["logprob"] != None:
+ if token["logprob"] is not None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
@@ -243,7 +276,7 @@ class PredibaseChatCompletion(BaseLLM):
else:
message_obj = Message(content=None)
choice_obj = Choices(
- finish_reason=item["finish_reason"],
+ finish_reason=map_finish_reason(item["finish_reason"]),
index=idx + 1,
message=message_obj,
)
@@ -253,10 +286,8 @@ class PredibaseChatCompletion(BaseLLM):
## CALCULATING USAGE
prompt_tokens = 0
try:
- prompt_tokens = len(
- encoding.encode(model_response["choices"][0]["message"]["content"])
- ) ##[TODO] use a model-specific tokenizer here
- except:
+ prompt_tokens = litellm.token_counter(messages=messages)
+ except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
output_text = model_response["choices"][0]["message"].get("content", "")
@@ -299,15 +330,17 @@ class PredibaseChatCompletion(BaseLLM):
logging_obj,
optional_params: dict,
tenant_id: str,
+ timeout: Union[float, httpx.Timeout],
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: dict = {},
) -> Union[ModelResponse, CustomStreamWrapper]:
- headers = self._validate_environment(api_key, headers)
+ headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
completion_url = ""
input_text = ""
base_url = "https://serving.app.predibase.com"
+
if "https" in model:
completion_url = model
elif api_base:
@@ -317,7 +350,7 @@ class PredibaseChatCompletion(BaseLLM):
completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
- if optional_params.get("stream", False) == True:
+ if optional_params.get("stream", False) is True:
completion_url += "/generate_stream"
else:
completion_url += "/generate"
@@ -361,9 +394,9 @@ class PredibaseChatCompletion(BaseLLM):
},
)
## COMPLETION CALL
- if acompletion == True:
+ if acompletion is True:
### ASYNC STREAMING
- if stream == True:
+ if stream is True:
return self.async_streaming(
model=model,
messages=messages,
@@ -378,6 +411,7 @@ class PredibaseChatCompletion(BaseLLM):
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
+ timeout=timeout,
) # type: ignore
else:
### ASYNC COMPLETION
@@ -396,10 +430,11 @@ class PredibaseChatCompletion(BaseLLM):
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
+ timeout=timeout,
) # type: ignore
### SYNC STREAMING
- if stream == True:
+ if stream is True:
response = requests.post(
completion_url,
headers=headers,
@@ -420,7 +455,6 @@ class PredibaseChatCompletion(BaseLLM):
headers=headers,
data=json.dumps(data),
)
-
return self.process_response(
model=model,
response=response,
@@ -448,16 +482,26 @@ class PredibaseChatCompletion(BaseLLM):
stream,
data: dict,
optional_params: dict,
+ timeout: Union[float, httpx.Timeout],
litellm_params=None,
logger_fn=None,
headers={},
) -> ModelResponse:
- self.async_handler = AsyncHTTPHandler(
- timeout=httpx.Timeout(timeout=600.0, connect=5.0)
- )
- response = await self.async_handler.post(
- api_base, headers=headers, data=json.dumps(data)
- )
+
+ async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
+ try:
+ response = await async_handler.post(
+ api_base, headers=headers, data=json.dumps(data)
+ )
+ except httpx.HTTPStatusError as e:
+ raise PredibaseError(
+ status_code=e.response.status_code,
+ message="HTTPStatusError - {}".format(e.response.text),
+ )
+ except Exception as e:
+ raise PredibaseError(
+ status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
+ )
return self.process_response(
model=model,
response=response,
@@ -483,31 +527,25 @@ class PredibaseChatCompletion(BaseLLM):
api_key,
logging_obj,
data: dict,
+ timeout: Union[float, httpx.Timeout],
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
) -> CustomStreamWrapper:
- self.async_handler = AsyncHTTPHandler(
- timeout=httpx.Timeout(timeout=600.0, connect=5.0)
- )
data["stream"] = True
- response = await self.async_handler.post(
- url=api_base,
- headers=headers,
- data=json.dumps(data),
- stream=True,
- )
-
- if response.status_code != 200:
- raise PredibaseError(
- status_code=response.status_code, message=response.text
- )
-
- completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
- completion_stream=completion_stream,
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ api_base=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
model=model,
custom_llm_provider="predibase",
logging_obj=logging_obj,
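The streaming path above now defers the HTTP call: instead of opening the connection up front and handing `response.aiter_lines()` to `CustomStreamWrapper`, the wrapper receives `completion_stream=None` plus a `functools.partial` over the new `make_call()`, so the request is only issued once the stream is actually consumed, and it honours the per-request `timeout` rather than a hard-coded 600s. A generic, self-contained illustration of the same lazy-stream pattern (placeholder names, not Predibase code):

```python
import asyncio
from functools import partial

async def open_stream(url: str, payload: dict):
    # stand-in for make_call(): open the connection and return an async iterator
    async def lines():
        yield '{"token": "hello"}'
        yield '{"token": " world"}'
    return lines()

# bind the request arguments now; nothing is sent yet
factory = partial(open_stream, "https://example.invalid/generate_stream", {"inputs": "hi"})

async def consume():
    stream = await factory()  # the connection is opened only here
    async for line in stream:
        print(line)

asyncio.run(consume())
```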
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 41ecb486c..6bf03b52d 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -3,14 +3,7 @@ import requests, traceback
import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, meta, BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment
-from typing import (
- Any,
- List,
- Mapping,
- MutableMapping,
- Optional,
- Sequence,
-)
+from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
import litellm
import litellm.types
from litellm.types.completion import (
@@ -24,7 +17,7 @@ from litellm.types.completion import (
import litellm.types.llms
from litellm.types.llms.anthropic import *
import uuid
-
+from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
import litellm.types.llms.vertex_ai
@@ -833,7 +826,7 @@ def anthropic_messages_pt_xml(messages: list):
) # either string or none
if messages[msg_i].get(
"tool_calls", []
- ): # support assistant tool invoke convertion
+ ): # support assistant tool invoke conversion
assistant_text += convert_to_anthropic_tool_invoke_xml( # type: ignore
messages[msg_i]["tool_calls"]
)
@@ -1224,7 +1217,7 @@ def anthropic_messages_pt(messages: list):
if messages[msg_i].get(
"tool_calls", []
- ): # support assistant tool invoke convertion
+ ): # support assistant tool invoke conversion
assistant_content.extend(
convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"])
)
@@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
try:
from PIL import Image
except:
- raise Exception(
- "gemini image conversion failed please run `pip install Pillow`"
- )
+ raise Exception("image conversion failed, please run `pip install Pillow`")
from io import BytesIO
try:
@@ -1613,6 +1604,380 @@ def azure_text_pt(messages: list):
return prompt
+###### AMAZON BEDROCK #######
+
+from litellm.types.llms.bedrock import (
+ ToolResultContentBlock as BedrockToolResultContentBlock,
+ ToolResultBlock as BedrockToolResultBlock,
+ ToolConfigBlock as BedrockToolConfigBlock,
+ ToolUseBlock as BedrockToolUseBlock,
+ ImageSourceBlock as BedrockImageSourceBlock,
+ ImageBlock as BedrockImageBlock,
+ ContentBlock as BedrockContentBlock,
+ ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
+ ToolSpecBlock as BedrockToolSpecBlock,
+ ToolBlock as BedrockToolBlock,
+ ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
+)
+
+
+def get_image_details(image_url) -> Tuple[str, str]:
+ try:
+ import base64
+
+ # Send a GET request to the image URL
+ response = requests.get(image_url)
+ response.raise_for_status() # Raise an exception for HTTP errors
+
+ # Check the response's content type to ensure it is an image
+ content_type = response.headers.get("content-type")
+ if not content_type or "image" not in content_type:
+ raise ValueError(
+ f"URL does not point to a valid image (content-type: {content_type})"
+ )
+
+ # Convert the image content to base64 bytes
+ base64_bytes = base64.b64encode(response.content).decode("utf-8")
+
+ # Get the image format (the content-type subtype, e.g. "png" from "image/png")
+ image_format = content_type.split("/")[1]
+
+ return base64_bytes, image_format
+
+ except requests.RequestException as e:
+ raise Exception(f"Request failed: {e}")
+ except Exception as e:
+ raise e
+
+
+def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
+ if "base64" in image_url:
+ # Case 1: Images with base64 encoding
+ import base64, re
+
+ # base 64 is passed as data:image/jpeg;base64,
+ image_metadata, img_without_base_64 = image_url.split(",")
+
+ # read mime_type from img_without_base_64=data:image/jpeg;base64
+ # Extract MIME type using regular expression
+ mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
+ if mime_type_match:
+ mime_type = mime_type_match.group(1)
+ image_format = mime_type.split("/")[1]
+ else:
+ mime_type = "image/jpeg"
+ image_format = "jpeg"
+ _blob = BedrockImageSourceBlock(bytes=img_without_base_64)
+ supported_image_formats = (
+ litellm.AmazonConverseConfig().get_supported_image_types()
+ )
+ if image_format in supported_image_formats:
+ return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
+ else:
+ # Handle the case when the image format is not supported
+ raise ValueError(
+ "Unsupported image format: {}. Supported formats: {}".format(
+ image_format, supported_image_formats
+ )
+ )
+ elif "https:/" in image_url:
+ # Case 2: Images with direct links
+ image_bytes, image_format = get_image_details(image_url)
+ _blob = BedrockImageSourceBlock(bytes=image_bytes)
+ supported_image_formats = (
+ litellm.AmazonConverseConfig().get_supported_image_types()
+ )
+ if image_format in supported_image_formats:
+ return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
+ else:
+ # Handle the case when the image format is not supported
+ raise ValueError(
+ "Unsupported image format: {}. Supported formats: {}".format(
+ image_format, supported_image_formats
+ )
+ )
+ else:
+ raise ValueError(
+ "Unsupported image type. Expected either image url or base64 encoded string - \
+ e.g. 'data:image/jpeg;base64,'"
+ )
+
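`_process_bedrock_converse_image_block()` accepts either a `data:<mime>;base64,<payload>` string or an https URL, checks the detected format against `litellm.AmazonConverseConfig().get_supported_image_types()`, and returns a Bedrock image block. A small sketch; the block types appear to be TypedDicts in `litellm.types.llms.bedrock`, so the result prints as a plain dict, and the base64 payload here is a placeholder (it is stored, not decoded):

```python
from litellm.llms.prompt_templates.factory import _process_bedrock_converse_image_block

data_url = "data:image/png;base64,iVBORw0KGgo="  # placeholder payload
block = _process_bedrock_converse_image_block(image_url=data_url)
print(block)  # e.g. {"source": {"bytes": "iVBORw0KGgo="}, "format": "png"}
```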
+
+def _convert_to_bedrock_tool_call_invoke(
+ tool_calls: list,
+) -> List[BedrockContentBlock]:
+ """
+ OpenAI tool invokes:
+ {
+ "role": "assistant",
+ "content": null,
+ "tool_calls": [
+ {
+ "id": "call_abc123",
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "arguments": "{\n\"location\": \"Boston, MA\"\n}"
+ }
+ }
+ ]
+ },
+ """
+ """
+ Bedrock tool invokes:
+ [
+ {
+ "role": "assistant",
+ "toolUse": {
+ "input": {"location": "Boston, MA", ..},
+ "name": "get_current_weather",
+ "toolUseId": "call_abc123"
+ }
+ }
+ ]
+ """
+ """
+ - json.loads argument
+ - extract name
+ - extract id
+ """
+
+ try:
+ _parts_list: List[BedrockContentBlock] = []
+ for tool in tool_calls:
+ if "function" in tool:
+ id = tool["id"]
+ name = tool["function"].get("name", "")
+ arguments = tool["function"].get("arguments", "")
+ arguments_dict = json.loads(arguments)
+ bedrock_tool = BedrockToolUseBlock(
+ input=arguments_dict, name=name, toolUseId=id
+ )
+ bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
+ _parts_list.append(bedrock_content_block)
+ return _parts_list
+ except Exception as e:
+ raise Exception(
+ "Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
+ tool_calls, str(e)
+ )
+ )
+
+
+def _convert_to_bedrock_tool_call_result(
+ message: dict,
+) -> BedrockMessageBlock:
+ """
+ OpenAI message with a tool result looks like:
+ {
+ "tool_call_id": "tool_1",
+ "role": "tool",
+ "name": "get_current_weather",
+ "content": "function result goes here",
+ },
+
+ OpenAI message with a function call result looks like:
+ {
+ "role": "function",
+ "name": "get_current_weather",
+ "content": "function result goes here",
+ }
+ """
+ """
+ Bedrock result looks like this:
+ {
+ "role": "user",
+ "content": [
+ {
+ "toolResult": {
+ "toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
+ "content": [
+ {
+ "json": {
+ "song": "Elemental Hotel",
+ "artist": "8 Storey Hike"
+ }
+ }
+ ]
+ }
+ }
+ ]
+ }
+ """
+ """
+ -
+ """
+ content = message.get("content", "")
+ name = message.get("name", "")
+ id = message.get("tool_call_id", str(uuid.uuid4()))
+
+ tool_result_content_block = BedrockToolResultContentBlock(text=content)
+ tool_result = BedrockToolResultBlock(
+ content=[tool_result_content_block],
+ toolUseId=id,
+ )
+ content_block = BedrockContentBlock(toolResult=tool_result)
+
+ return BedrockMessageBlock(role="user", content=[content_block])
+
+
+def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
+ """
+ Converts given messages from OpenAI format to Bedrock format
+
+ - Roles must alternate b/w 'user' and 'assistant' (same as anthropic -> merge consecutive same-role turns)
+ - Please ensure that the function response turn comes immediately after the function call turn
+ """
+
+ contents: List[BedrockMessageBlock] = []
+ msg_i = 0
+ while msg_i < len(messages):
+ user_content: List[BedrockContentBlock] = []
+ init_msg_i = msg_i
+ ## MERGE CONSECUTIVE USER CONTENT ##
+ while msg_i < len(messages) and messages[msg_i]["role"] == "user":
+ if isinstance(messages[msg_i]["content"], list):
+ _parts: List[BedrockContentBlock] = []
+ for element in messages[msg_i]["content"]:
+ if isinstance(element, dict):
+ if element["type"] == "text":
+ _part = BedrockContentBlock(text=element["text"])
+ _parts.append(_part)
+ elif element["type"] == "image_url":
+ image_url = element["image_url"]["url"]
+ _part = _process_bedrock_converse_image_block( # type: ignore
+ image_url=image_url
+ )
+ _parts.append(BedrockContentBlock(image=_part)) # type: ignore
+ user_content.extend(_parts)
+ else:
+ _part = BedrockContentBlock(text=messages[msg_i]["content"])
+ user_content.append(_part)
+
+ msg_i += 1
+
+ if user_content:
+ contents.append(BedrockMessageBlock(role="user", content=user_content))
+ assistant_content: List[BedrockContentBlock] = []
+ ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+ while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+ if isinstance(messages[msg_i]["content"], list):
+ assistants_parts: List[BedrockContentBlock] = []
+ for element in messages[msg_i]["content"]:
+ if isinstance(element, dict):
+ if element["type"] == "text":
+ assistants_part = BedrockContentBlock(text=element["text"])
+ assistants_parts.append(assistants_part)
+ elif element["type"] == "image_url":
+ image_url = element["image_url"]["url"]
+ assistants_part = _process_bedrock_converse_image_block( # type: ignore
+ image_url=image_url
+ )
+ assistants_parts.append(
+ BedrockContentBlock(image=assistants_part) # type: ignore
+ )
+ assistant_content.extend(assistants_parts)
+ elif messages[msg_i].get(
+ "tool_calls", []
+ ): # support assistant tool invoke conversion
+ assistant_content.extend(
+ _convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
+ )
+ else:
+ assistant_text = (
+ messages[msg_i].get("content") or ""
+ ) # either string or none
+ if assistant_text:
+ assistant_content.append(BedrockContentBlock(text=assistant_text))
+
+ msg_i += 1
+
+ if assistant_content:
+ contents.append(
+ BedrockMessageBlock(role="assistant", content=assistant_content)
+ )
+
+ ## APPEND TOOL CALL MESSAGES ##
+ if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+ tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
+ contents.append(tool_call_result)
+ msg_i += 1
+ if msg_i == init_msg_i: # prevent infinite loops
+ raise Exception(
+ "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+ messages[msg_i]
+ )
+ )
+
+ return contents
+
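Taken together, `_bedrock_converse_messages_pt()` merges consecutive same-role turns and routes OpenAI-style `tool_calls` / `role: "tool"` messages through the two converters above, producing Bedrock's alternating `user`/`assistant` blocks (tool results are reported under the `user` role). A short sketch with a tool-call round trip, assuming the block types behave like plain dicts:

```python
from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt

messages = [
    {"role": "user", "content": "What's the weather in Boston?"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"location": "Boston, MA"}',
                },
            }
        ],
    },
    {
        "role": "tool",
        "tool_call_id": "call_abc123",
        "name": "get_current_weather",
        "content": "72F and sunny",
    },
]

blocks = _bedrock_converse_messages_pt(messages)
for block in blocks:
    # expected: user/text, assistant/toolUse, user/toolResult
    print(block["role"], list(block["content"][0].keys()))
```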
+
+def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
+ """
+ OpenAI tools looks like:
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather in a given location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+ },
+ "required": ["location"],
+ },
+ }
+ }
+ ]
+ """
+ """
+ Bedrock toolConfig looks like:
+ "tools": [
+ {
+ "toolSpec": {
+ "name": "top_song",
+ "description": "Get the most popular song played on a radio station.",
+ "inputSchema": {
+ "json": {
+ "type": "object",
+ "properties": {
+ "sign": {
+ "type": "string",
+ "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."
+ }
+ },
+ "required": [
+ "sign"
+ ]
+ }
+ }
+ }
+ }
+ ]
+ """
+ tool_block_list: List[BedrockToolBlock] = []
+ for tool in tools:
+ parameters = tool.get("function", {}).get("parameters", None)
+ name = tool.get("function", {}).get("name", "")
+ description = tool.get("function", {}).get("description", "")
+ tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
+ tool_spec = BedrockToolSpecBlock(
+ inputSchema=tool_input_schema, name=name, description=description
+ )
+ tool_block = BedrockToolBlock(toolSpec=tool_spec)
+ tool_block_list.append(tool_block)
+
+ return tool_block_list
+
+
# Function call template
def function_call_prompt(messages: list, functions: list):
function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""
diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index 386d24f59..ce62e51e9 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -251,7 +251,7 @@ async def async_handle_prediction_response(
logs = ""
while True and (status not in ["succeeded", "failed", "canceled"]):
print_verbose(f"replicate: polling endpoint: {prediction_url}")
- await asyncio.sleep(0.5)
+ await asyncio.sleep(0.5) # prevent replicate rate limit errors
response = await http_handler.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index dc185aef9..bd9cfaa8d 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -3,7 +3,7 @@ import json
from enum import Enum
import requests # type: ignore
import time
-from typing import Callable, Optional, Union, List, Literal
+from typing import Callable, Optional, Union, List, Literal, Any
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
@@ -12,6 +12,7 @@ from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_result,
convert_to_gemini_tool_call_invoke,
)
+from litellm.types.files import (
+ get_file_mime_type_for_file_type,
+ get_file_type_from_extension,
+ is_gemini_1_5_accepted_file_type,
+ is_video_file_type,
+)
class VertexAIError(Exception):
@@ -297,24 +298,31 @@ def _convert_gemini_role(role: str) -> Literal["user", "model"]:
def _process_gemini_image(image_url: str) -> PartType:
try:
+ # GCS URIs
if "gs://" in image_url:
- # Case 1: Images with Cloud Storage URIs
- # The supported MIME types for images include image/png and image/jpeg.
- part_mime = "image/png" if "png" in image_url else "image/jpeg"
- _file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
- return PartType(file_data=_file_data)
+ # Figure out file type
+ extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
+ extension = extension_with_dot[1:] # Ex: "png"
+
+ file_type = get_file_type_from_extension(extension)
+
+ # Validate the file type is supported by Gemini
+ if not is_gemini_1_5_accepted_file_type(file_type):
+ raise Exception(f"File type not supported by gemini - {file_type}")
+
+ mime_type = get_file_mime_type_for_file_type(file_type)
+ file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
+
+ return PartType(file_data=file_data)
+
+ # Direct links
elif "https:/" in image_url:
- # Case 2: Images with direct links
image = _load_image_from_url(image_url)
_blob = BlobType(data=image.data, mime_type=image._mime_type)
return PartType(inline_data=_blob)
- elif ".mp4" in image_url and "gs://" in image_url:
- # Case 3: Videos with Cloud Storage URIs
- part_mime = "video/mp4"
- _file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
- return PartType(file_data=_file_data)
+
+ # Base64 encoding
elif "base64" in image_url:
- # Case 4: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,
@@ -390,7 +398,7 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
assistant_content.extend(_parts)
elif messages[msg_i].get(
"tool_calls", []
- ): # support assistant tool invoke convertion
+ ): # support assistant tool invoke conversion
assistant_content.extend(
convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
)
@@ -421,110 +429,17 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
return contents
-def _gemini_vision_convert_messages(messages: list):
- """
- Converts given messages for GPT-4 Vision to Gemini format.
+def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
+ _cache_key = f"{model}-{vertex_project}-{vertex_location}"
+ return _cache_key
- Args:
- messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
- - If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
- - If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.
- Returns:
- tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
+def _get_client_from_cache(client_cache_key: str):
+ return litellm.in_memory_llm_clients_cache.get(client_cache_key, None)
- Raises:
- VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
- Exception: If any other exception occurs during the execution of the function.
- Note:
- This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub.
- The supported MIME types for images include 'image/png' and 'image/jpeg'.
-
- Examples:
- >>> messages = [
- ... {"content": "Hello, world!"},
- ... {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": "example.com/image.png"}]},
- ... ]
- >>> _gemini_vision_convert_messages(messages)
- ('Hello, world!This is a text message.', [, ])
- """
- try:
- import vertexai
- except:
- raise VertexAIError(
- status_code=400,
- message="vertexai import failed please run `pip install google-cloud-aiplatform`",
- )
- try:
- from vertexai.preview.language_models import (
- ChatModel,
- CodeChatModel,
- InputOutputTextPair,
- )
- from vertexai.language_models import TextGenerationModel, CodeGenerationModel
- from vertexai.preview.generative_models import (
- GenerativeModel,
- Part,
- GenerationConfig,
- Image,
- )
-
- # given messages for gpt-4 vision, convert them for gemini
- # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
- prompt = ""
- images = []
- for message in messages:
- if isinstance(message["content"], str):
- prompt += message["content"]
- elif isinstance(message["content"], list):
- # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
- for element in message["content"]:
- if isinstance(element, dict):
- if element["type"] == "text":
- prompt += element["text"]
- elif element["type"] == "image_url":
- image_url = element["image_url"]["url"]
- images.append(image_url)
- # processing images passed to gemini
- processed_images = []
- for img in images:
- if "gs://" in img:
- # Case 1: Images with Cloud Storage URIs
- # The supported MIME types for images include image/png and image/jpeg.
- part_mime = "image/png" if "png" in img else "image/jpeg"
- google_clooud_part = Part.from_uri(img, mime_type=part_mime)
- processed_images.append(google_clooud_part)
- elif "https:/" in img:
- # Case 2: Images with direct links
- image = _load_image_from_url(img)
- processed_images.append(image)
- elif ".mp4" in img and "gs://" in img:
- # Case 3: Videos with Cloud Storage URIs
- part_mime = "video/mp4"
- google_clooud_part = Part.from_uri(img, mime_type=part_mime)
- processed_images.append(google_clooud_part)
- elif "base64" in img:
- # Case 4: Images with base64 encoding
- import base64, re
-
- # base 64 is passed as data:image/jpeg;base64,
- image_metadata, img_without_base_64 = img.split(",")
-
- # read mime_type from img_without_base_64=data:image/jpeg;base64
- # Extract MIME type using regular expression
- mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
-
- if mime_type_match:
- mime_type = mime_type_match.group(1)
- else:
- mime_type = "image/jpeg"
- decoded_img = base64.b64decode(img_without_base_64)
- processed_image = Part.from_data(data=decoded_img, mime_type=mime_type)
- processed_images.append(processed_image)
- return prompt, processed_images
- except Exception as e:
- raise e
+def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
+ litellm.in_memory_llm_clients_cache[client_cache_key] = vertex_llm_model
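These three helpers add a simple in-process cache, keyed on model, project, and location and stored in `litellm.in_memory_llm_clients_cache`, so repeat Vertex AI calls can reuse an initialized model object instead of re-running `vertexai.init()` and the `from_pretrained()`/`GenerativeModel()` constructors. A minimal sketch of the lookup pattern (project and location values are placeholders):

```python
from litellm.llms.vertex_ai import (
    _get_client_cache_key,
    _get_client_from_cache,
    _set_client_in_cache,
)

key = _get_client_cache_key(
    model="gemini-1.5-pro", vertex_project="my-project", vertex_location="us-central1"
)
if _get_client_from_cache(key) is None:
    # in completion()/async_completion() this would be the initialized llm_model
    _set_client_in_cache(client_cache_key=key, vertex_llm_model=object())

cached_model = _get_client_from_cache(key)  # hit on every later call with the same key
```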
def completion(
@@ -580,23 +495,32 @@ def completion(
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
- if vertex_credentials is not None and isinstance(vertex_credentials, str):
- import google.oauth2.service_account
- json_obj = json.loads(vertex_credentials)
+ _cache_key = _get_client_cache_key(
+ model=model, vertex_project=vertex_project, vertex_location=vertex_location
+ )
+ _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)
- creds = google.oauth2.service_account.Credentials.from_service_account_info(
- json_obj,
- scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ if _vertex_llm_model_object is None:
+ if vertex_credentials is not None and isinstance(vertex_credentials, str):
+ import google.oauth2.service_account
+
+ json_obj = json.loads(vertex_credentials)
+
+ creds = (
+ google.oauth2.service_account.Credentials.from_service_account_info(
+ json_obj,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ )
+ else:
+ creds, _ = google.auth.default(quota_project_id=vertex_project)
+ print_verbose(
+ f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
+ )
+ vertexai.init(
+ project=vertex_project, location=vertex_location, credentials=creds
)
- else:
- creds, _ = google.auth.default(quota_project_id=vertex_project)
- print_verbose(
- f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
- )
- vertexai.init(
- project=vertex_project, location=vertex_location, credentials=creds
- )
## Load Config
config = litellm.VertexAIConfig.get_config()
@@ -620,9 +544,9 @@ def completion(
prompt = " ".join(
[
- message["content"]
+ message.get("content")
for message in messages
- if isinstance(message["content"], str)
+ if isinstance(message.get("content", None), str)
]
)
@@ -639,23 +563,27 @@ def completion(
model in litellm.vertex_language_models
or model in litellm.vertex_vision_models
):
- llm_model = GenerativeModel(model)
+ llm_model = _vertex_llm_model_object or GenerativeModel(model)
mode = "vision"
request_str += f"llm_model = GenerativeModel({model})\n"
elif model in litellm.vertex_chat_models:
- llm_model = ChatModel.from_pretrained(model)
+ llm_model = _vertex_llm_model_object or ChatModel.from_pretrained(model)
mode = "chat"
request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
elif model in litellm.vertex_text_models:
- llm_model = TextGenerationModel.from_pretrained(model)
+ llm_model = _vertex_llm_model_object or TextGenerationModel.from_pretrained(
+ model
+ )
mode = "text"
request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
elif model in litellm.vertex_code_text_models:
- llm_model = CodeGenerationModel.from_pretrained(model)
+ llm_model = _vertex_llm_model_object or CodeGenerationModel.from_pretrained(
+ model
+ )
mode = "text"
request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
elif model in litellm.vertex_code_chat_models: # vertex_code_llm_models
- llm_model = CodeChatModel.from_pretrained(model)
+ llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
mode = "chat"
request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
elif model == "private":
@@ -1034,6 +962,15 @@ async def async_completion(
tools=tools,
)
+ _cache_key = _get_client_cache_key(
+ model=model,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ )
+ _set_client_in_cache(
+ client_cache_key=_cache_key, vertex_llm_model=llm_model
+ )
+
if tools is not None and bool(
getattr(response.candidates[0].content.parts[0], "function_call", None)
):
diff --git a/litellm/main.py b/litellm/main.py
index 37fc1db8f..2c906e990 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
-from .llms.bedrock_httpx import BedrockLLM
+from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
@@ -92,6 +92,7 @@ import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict, Union, Mapping
from .caching import enable_cache, disable_cache, update_cache
+from .types.llms.openai import HttpxBinaryResponseContent
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
@@ -121,6 +122,7 @@ huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
+bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################
@@ -223,7 +225,7 @@ async def acompletion(
extra_headers: Optional[dict] = None,
# Optional liteLLM function params
**kwargs,
-):
+) -> Union[ModelResponse, CustomStreamWrapper]:
"""
Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -294,6 +296,7 @@ async def acompletion(
"api_version": api_version,
"api_key": api_key,
"model_list": model_list,
+ "extra_headers": extra_headers,
"acompletion": True, # assuming this is a required parameter
}
if custom_llm_provider is None:
@@ -338,6 +341,8 @@ async def acompletion(
if isinstance(init_response, dict) or isinstance(
init_response, ModelResponse
): ## CACHING SCENARIO
+ if isinstance(init_response, dict):
+ init_response = ModelResponse(**init_response)
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
@@ -360,6 +365,10 @@ async def acompletion(
) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
return response
except Exception as e:
+ verbose_logger.error(
+ "litellm.acompletion(): Exception occurred - {}".format(str(e))
+ )
+ verbose_logger.debug(traceback.format_exc())
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
model=model,
@@ -423,12 +432,16 @@ def mock_completion(
if isinstance(mock_response, openai.APIError):
raise mock_response
raise litellm.APIError(
- status_code=500, # type: ignore
- message=str(mock_response),
- llm_provider="openai", # type: ignore
+ status_code=getattr(mock_response, "status_code", 500), # type: ignore
+ message=getattr(mock_response, "text", str(mock_response)),
+ llm_provider=getattr(mock_response, "llm_provider", "openai"), # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
+ time_delay = kwargs.get("mock_delay", None)
+ if time_delay is not None:
+ time.sleep(time_delay)
+
model_response = ModelResponse(stream=stream)
if stream is True:
# don't try to access stream object,
@@ -459,16 +472,25 @@ def mock_completion(
try:
_, custom_llm_provider, _, _ = litellm.utils.get_llm_provider(model=model)
model_response._hidden_params["custom_llm_provider"] = custom_llm_provider
- except:
+ except Exception:
# don't let setting a hidden param block a mock_response
pass
+ if logging is not None:
+ logging.post_call(
+ input=messages,
+ api_key="my-secret-key",
+ original_response="my-original-response",
+ )
return model_response
except Exception as e:
if isinstance(e, openai.APIError):
raise e
- traceback.print_exc()
+ verbose_logger.error(
+ "litellm.mock_completion(): Exception occurred - {}".format(str(e))
+ )
+ verbose_logger.debug(traceback.format_exc())
raise Exception("Mock completion response failed")
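For reference, both knobs are reachable from the public `litellm.completion()` call: `mock_response` short-circuits the provider request, and the new `mock_delay` (seconds) sleeps before returning, which is convenient for exercising timeout and latency handling in tests. A small usage sketch:

```python
import litellm

resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
    mock_response="pong",  # skip the real provider call
    mock_delay=0.5,        # new: sleep 0.5s before returning the mock
)
print(resp.choices[0].message.content)  # -> "pong"
```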
@@ -679,6 +701,7 @@ def completion(
"region_name",
"allowed_model_region",
"model_config",
+ "fastest_response",
]
default_params = openai_params + litellm_params
@@ -828,6 +851,7 @@ def completion(
logprobs=logprobs,
top_logprobs=top_logprobs,
extra_headers=extra_headers,
+ api_version=api_version,
**non_default_params,
)
@@ -878,6 +902,7 @@ def completion(
mock_response=mock_response,
logging=logging,
acompletion=acompletion,
+ mock_delay=kwargs.get("mock_delay", None),
)
if custom_llm_provider == "azure":
# azure configs
@@ -1924,7 +1949,8 @@ def completion(
)
api_base = (
- optional_params.pop("api_base", None)
+ api_base
+ or optional_params.pop("api_base", None)
or optional_params.pop("base_url", None)
or litellm.api_base
or get_secret("PREDIBASE_API_BASE")
@@ -1952,12 +1978,13 @@ def completion(
custom_prompt_dict=custom_prompt_dict,
api_key=api_key,
tenant_id=tenant_id,
+ timeout=timeout,
)
if (
"stream" in optional_params
- and optional_params["stream"] == True
- and acompletion == False
+ and optional_params["stream"] is True
+ and acompletion is False
):
return _model_response
response = _model_response
@@ -2085,22 +2112,40 @@ def completion(
logging_obj=logging,
)
else:
- response = bedrock_chat_completion.completion(
- model=model,
- messages=messages,
- custom_prompt_dict=custom_prompt_dict,
- model_response=model_response,
- print_verbose=print_verbose,
- optional_params=optional_params,
- litellm_params=litellm_params,
- logger_fn=logger_fn,
- encoding=encoding,
- logging_obj=logging,
- extra_headers=extra_headers,
- timeout=timeout,
- acompletion=acompletion,
- client=client,
- )
+ if model.startswith("anthropic"):
+ response = bedrock_converse_chat_completion.completion(
+ model=model,
+ messages=messages,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ encoding=encoding,
+ logging_obj=logging,
+ extra_headers=extra_headers,
+ timeout=timeout,
+ acompletion=acompletion,
+ client=client,
+ )
+ else:
+ response = bedrock_chat_completion.completion(
+ model=model,
+ messages=messages,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ encoding=encoding,
+ logging_obj=logging,
+ extra_headers=extra_headers,
+ timeout=timeout,
+ acompletion=acompletion,
+ client=client,
+ )
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
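Routing note: with the branch above, Bedrock requests whose model id starts with `anthropic` go through the new `BedrockConverseLLM` (Converse API) client, while other Bedrock models keep the existing `BedrockLLM` invoke path. Both are reached through the same public call; region and credentials come from the usual AWS environment variables:

```python
import litellm

messages = [{"role": "user", "content": "Hello"}]

# "anthropic.*" -> BedrockConverseLLM (Converse API)
litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0", messages=messages
)

# other Bedrock models (e.g. Amazon Titan) -> legacy BedrockLLM path
litellm.completion(model="bedrock/amazon.titan-text-express-v1", messages=messages)
```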
@@ -2403,6 +2448,7 @@ def completion(
"top_k": kwargs.get("top_k", 40),
},
},
+ verify=litellm.ssl_verify,
)
response_json = resp.json()
"""
@@ -3712,7 +3758,7 @@ async def amoderation(input: str, model: str, api_key: Optional[str] = None, **k
##### Image Generation #######################
@client
-async def aimage_generation(*args, **kwargs):
+async def aimage_generation(*args, **kwargs) -> ImageResponse:
"""
Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.
@@ -3745,6 +3791,8 @@ async def aimage_generation(*args, **kwargs):
if isinstance(init_response, dict) or isinstance(
init_response, ImageResponse
): ## CACHING SCENARIO
+ if isinstance(init_response, dict):
+ init_response = ImageResponse(**init_response)
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
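A small usage sketch for the typed `aimage_generation` above; on a cache hit the stored dict is now rebuilt into an `ImageResponse`, so callers get attribute access either way (model name assumed):

```python
import asyncio
import litellm

async def main():
    img = await litellm.aimage_generation(
        model="dall-e-3",             # assumed example model
        prompt="a watercolor fox",
    )
    # img is an ImageResponse whether it came from the provider or the cache
    print(img.data[0].url)

asyncio.run(main())
```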
@@ -3780,7 +3828,7 @@ def image_generation(
litellm_logging_obj=None,
custom_llm_provider=None,
**kwargs,
-):
+) -> ImageResponse:
"""
Maps the https://api.openai.com/v1/images/generations endpoint.
@@ -4112,7 +4160,7 @@ def transcription(
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_API_KEY")
- )
+ ) # type: ignore
response = azure_chat_completions.audio_transcriptions(
model=model,
@@ -4129,6 +4177,24 @@ def transcription(
max_retries=max_retries,
)
elif custom_llm_provider == "openai":
+ api_base = (
+ api_base
+ or litellm.api_base
+ or get_secret("OPENAI_API_BASE")
+ or "https://api.openai.com/v1"
+ ) # type: ignore
+ openai.organization = (
+ litellm.organization
+ or get_secret("OPENAI_ORGANIZATION")
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+ )
+ # set API KEY
+ api_key = (
+ api_key
+ or litellm.api_key
+ or litellm.openai_key
+ or get_secret("OPENAI_API_KEY")
+ ) # type: ignore
response = openai_chat_completions.audio_transcriptions(
model=model,
audio_file=file,
@@ -4138,6 +4204,139 @@ def transcription(
timeout=timeout,
logging_obj=litellm_logging_obj,
max_retries=max_retries,
+ api_base=api_base,
+ api_key=api_key,
+ )
+ return response
+
+
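A quick sketch of the OpenAI transcription path after the key/base resolution added above; assumes `OPENAI_API_KEY` is set and a local `speech.mp3` exists (both placeholders):

```python
import litellm

with open("speech.mp3", "rb") as audio_file:   # placeholder local file
    transcript = litellm.transcription(
        model="whisper-1",
        file=audio_file,
        # api_base / api_key can also be passed explicitly; otherwise they are
        # resolved from litellm.* settings and environment variables as above.
    )
print(transcript.text)
```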
+@client
+async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:
+ """
+ Asynchronously calls OpenAI's text-to-speech endpoints.
+ """
+ loop = asyncio.get_event_loop()
+ model = args[0] if len(args) > 0 else kwargs["model"]
+ ### PASS ARGS TO Speech ###
+ kwargs["aspeech"] = True
+ custom_llm_provider = kwargs.get("custom_llm_provider", None)
+ try:
+ # Use a partial function to pass your keyword arguments
+ func = partial(speech, *args, **kwargs)
+
+ # Add the context to the function
+ ctx = contextvars.copy_context()
+ func_with_context = partial(ctx.run, func)
+
+ _, custom_llm_provider, _, _ = get_llm_provider(
+ model=model, api_base=kwargs.get("api_base", None)
+ )
+
+ # Await normally
+ init_response = await loop.run_in_executor(None, func_with_context)
+ if asyncio.iscoroutine(init_response):
+ response = await init_response
+ else:
+ # The synchronous call already ran in the executor above; reuse its result
+ response = init_response
+ return response # type: ignore
+ except Exception as e:
+ custom_llm_provider = custom_llm_provider or "openai"
+ raise exception_type(
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ original_exception=e,
+ completion_kwargs=args,
+ extra_kwargs=kwargs,
+ )
+
+
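A usage sketch for the new async helper; `stream_to_file` comes from the returned `HttpxBinaryResponseContent` (OpenAI SDK type), and the model/voice values are just examples:

```python
import asyncio
import litellm

async def main():
    audio = await litellm.aspeech(
        model="tts-1",
        voice="alloy",
        input="Hello from LiteLLM",
    )
    audio.stream_to_file("hello.mp3")  # write the binary audio to disk

asyncio.run(main())
```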
+@client
+def speech(
+ model: str,
+ input: str,
+ voice: str,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ organization: Optional[str] = None,
+ project: Optional[str] = None,
+ max_retries: Optional[int] = None,
+ metadata: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ response_format: Optional[str] = None,
+ speed: Optional[int] = None,
+ client=None,
+ headers: Optional[dict] = None,
+ custom_llm_provider: Optional[str] = None,
+ aspeech: Optional[bool] = None,
+ **kwargs,
+) -> HttpxBinaryResponseContent:
+
+ model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
+
+ optional_params = {}
+ if response_format is not None:
+ optional_params["response_format"] = response_format
+ if speed is not None:
+ optional_params["speed"] = speed # type: ignore
+
+ if timeout is None:
+ timeout = litellm.request_timeout
+
+ if max_retries is None:
+ max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
+ response: Optional[HttpxBinaryResponseContent] = None
+ if custom_llm_provider == "openai":
+ api_base = (
+ api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+ or litellm.api_base
+ or get_secret("OPENAI_API_BASE")
+ or "https://api.openai.com/v1"
+ ) # type: ignore
+ # set API KEY
+ api_key = (
+ api_key
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+ or litellm.openai_key
+ or get_secret("OPENAI_API_KEY")
+ ) # type: ignore
+
+ organization = (
+ organization
+ or litellm.organization
+ or get_secret("OPENAI_ORGANIZATION")
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+ ) # type: ignore
+
+ project = (
+ project
+ or litellm.project
+ or get_secret("OPENAI_PROJECT")
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+ ) # type: ignore
+
+ headers = headers or litellm.headers
+
+ response = openai_chat_completions.audio_speech(
+ model=model,
+ input=input,
+ voice=voice,
+ optional_params=optional_params,
+ api_key=api_key,
+ api_base=api_base,
+ organization=organization,
+ project=project,
+ max_retries=max_retries,
+ timeout=timeout,
+ client=client, # pass AsyncOpenAI, OpenAI client
+ aspeech=aspeech,
+ )
+
+ if response is None:
+ raise Exception(
+ "Unable to map the custom llm provider={} to a known provider={}.".format(
+ custom_llm_provider, litellm.provider_list
+ )
)
return response
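And the synchronous counterpart, relying on the same resolution order for `api_key`/`api_base`/`organization`/`project` shown above (credentials assumed to come from the environment):

```python
import litellm

audio = litellm.speech(
    model="tts-1",
    voice="alloy",
    input="Hello from LiteLLM",
    response_format="mp3",
)
with open("hello.mp3", "wb") as f:
    f.write(audio.content)   # raw bytes from HttpxBinaryResponseContent
```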
@@ -4170,6 +4369,10 @@ async def ahealth_check(
mode = litellm.model_cost[model]["mode"]
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+
+ if model in litellm.model_cost and mode is None:
+ mode = litellm.model_cost[model]["mode"]
+
mode = mode or "chat" # default to chat completion calls
if custom_llm_provider == "azure":
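The added fallback only fires when the caller did not supply a mode and the (provider-stripped) model name is present in the price map; roughly:

```python
import litellm

model = "gpt-3.5-turbo"     # assumed example key in litellm.model_cost
mode = None                 # i.e. not supplied by the caller

# mirrors the fallback added above
if model in litellm.model_cost and mode is None:
    mode = litellm.model_cost[model].get("mode")
mode = mode or "chat"       # final default
print(mode)                 # -> "chat"
```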
@@ -4260,7 +4463,10 @@ async def ahealth_check(
response = {} # args like remaining ratelimit etc.
return response
except Exception as e:
- traceback.print_exc()
+ verbose_logger.error(
+ "litellm.ahealth_check(): Exception occurred - {}".format(str(e))
+ )
+ verbose_logger.debug(traceback.format_exc())
stack_trace = traceback.format_exc()
if isinstance(stack_trace, str):
stack_trace = stack_trace[:1000]
@@ -4366,7 +4572,7 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
def stream_chunk_builder(
chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
-):
+) -> Union[ModelResponse, TextCompletionResponse]:
model_response = litellm.ModelResponse()
### SORT CHUNKS BASED ON CREATED ORDER ##
print_verbose("Goes into checking if chunk has hiddden created at param")
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index aab9c9af1..f2b292c92 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -380,6 +380,18 @@
"output_cost_per_second": 0.0001,
"litellm_provider": "azure"
},
+ "azure/gpt-4o": {
+ "max_tokens": 4096,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000005,
+ "output_cost_per_token": 0.000015,
+ "litellm_provider": "azure",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_parallel_function_calling": true,
+ "supports_vision": true
+ },
"azure/gpt-4-turbo-2024-04-09": {
"max_tokens": 4096,
"max_input_tokens": 128000,
@@ -692,8 +704,8 @@
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
- "input_cost_per_token": 0.00000015,
- "output_cost_per_token": 0.00000046,
+ "input_cost_per_token": 0.00000025,
+ "output_cost_per_token": 0.00000025,
"litellm_provider": "mistral",
"mode": "chat"
},
@@ -701,8 +713,8 @@
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
- "input_cost_per_token": 0.000002,
- "output_cost_per_token": 0.000006,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000003,
"litellm_provider": "mistral",
"supports_function_calling": true,
"mode": "chat"
@@ -711,8 +723,8 @@
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
- "input_cost_per_token": 0.000002,
- "output_cost_per_token": 0.000006,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000003,
"litellm_provider": "mistral",
"supports_function_calling": true,
"mode": "chat"
@@ -748,8 +760,8 @@
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
- "input_cost_per_token": 0.000008,
- "output_cost_per_token": 0.000024,
+ "input_cost_per_token": 0.000004,
+ "output_cost_per_token": 0.000012,
"litellm_provider": "mistral",
"mode": "chat",
"supports_function_calling": true
@@ -758,26 +770,63 @@
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
- "input_cost_per_token": 0.000008,
- "output_cost_per_token": 0.000024,
+ "input_cost_per_token": 0.000004,
+ "output_cost_per_token": 0.000012,
"litellm_provider": "mistral",
"mode": "chat",
"supports_function_calling": true
},
+ "mistral/open-mistral-7b": {
+ "max_tokens": 8191,
+ "max_input_tokens": 32000,
+ "max_output_tokens": 8191,
+ "input_cost_per_token": 0.00000025,
+ "output_cost_per_token": 0.00000025,
+ "litellm_provider": "mistral",
+ "mode": "chat"
+ },
"mistral/open-mixtral-8x7b": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
+ "input_cost_per_token": 0.0000007,
+ "output_cost_per_token": 0.0000007,
+ "litellm_provider": "mistral",
+ "mode": "chat",
+ "supports_function_calling": true
+ },
+ "mistral/open-mixtral-8x22b": {
+ "max_tokens": 8191,
+ "max_input_tokens": 64000,
+ "max_output_tokens": 8191,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006,
"litellm_provider": "mistral",
"mode": "chat",
"supports_function_calling": true
},
+ "mistral/codestral-latest": {
+ "max_tokens": 8191,
+ "max_input_tokens": 32000,
+ "max_output_tokens": 8191,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000003,
+ "litellm_provider": "mistral",
+ "mode": "chat"
+ },
+ "mistral/codestral-2405": {
+ "max_tokens": 8191,
+ "max_input_tokens": 32000,
+ "max_output_tokens": 8191,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000003,
+ "litellm_provider": "mistral",
+ "mode": "chat"
+ },
"mistral/mistral-embed": {
"max_tokens": 8192,
"max_input_tokens": 8192,
- "input_cost_per_token": 0.000000111,
+ "input_cost_per_token": 0.0000001,
"litellm_provider": "mistral",
"mode": "embedding"
},
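The updated Mistral prices flow through LiteLLM's cost helpers; a sanity-check sketch against the new `open-mixtral-8x7b` rate ($0.0000007 per input and output token):

```python
import litellm

prompt_cost, completion_cost = litellm.cost_per_token(
    model="mistral/open-mixtral-8x7b",
    prompt_tokens=1_000,
    completion_tokens=500,
)
print(prompt_cost + completion_cost)   # ~0.00105 with the values above
```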
@@ -1128,6 +1177,24 @@
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
+ "gemini-1.5-flash-001": {
+ "max_tokens": 8192,
+ "max_input_tokens": 1000000,
+ "max_output_tokens": 8192,
+ "max_images_per_prompt": 3000,
+ "max_videos_per_prompt": 10,
+ "max_video_length": 1,
+ "max_audio_length_hours": 8.4,
+ "max_audio_per_prompt": 1,
+ "max_pdf_size_mb": 30,
+ "input_cost_per_token": 0,
+ "output_cost_per_token": 0,
+ "litellm_provider": "vertex_ai-language-models",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_vision": true,
+ "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
+ },
"gemini-1.5-flash-preview-0514": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
@@ -1146,6 +1213,18 @@
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
+ "gemini-1.5-pro-001": {
+ "max_tokens": 8192,
+ "max_input_tokens": 1000000,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.000000625,
+ "output_cost_per_token": 0.000001875,
+ "litellm_provider": "vertex_ai-language-models",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": true,
+ "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
+ },
"gemini-1.5-pro-preview-0514": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
@@ -1265,8 +1344,8 @@
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
- "input_cost_per_token": 0.0000015,
- "output_cost_per_token": 0.0000075,
+ "input_cost_per_token": 0.000015,
+ "output_cost_per_token": 0.000075,
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
@@ -1421,7 +1500,7 @@
"max_pdf_size_mb": 30,
"input_cost_per_token": 0,
"output_cost_per_token": 0,
- "litellm_provider": "vertex_ai-language-models",
+ "litellm_provider": "gemini",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
@@ -2930,32 +3009,37 @@
"litellm_provider": "sagemaker",
"mode": "chat"
},
- "together-ai-up-to-3b": {
+ "together-ai-up-to-4b": {
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.0000001,
"litellm_provider": "together_ai"
},
- "together-ai-3.1b-7b": {
+ "together-ai-4.1b-8b": {
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "together_ai"
},
- "together-ai-7.1b-20b": {
+ "together-ai-8.1b-21b": {
"max_tokens": 1000,
- "input_cost_per_token": 0.0000004,
- "output_cost_per_token": 0.0000004,
+ "input_cost_per_token": 0.0000003,
+ "output_cost_per_token": 0.0000003,
"litellm_provider": "together_ai"
},
- "together-ai-20.1b-40b": {
+ "together-ai-21.1b-41b": {
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000008,
"litellm_provider": "together_ai"
},
- "together-ai-40.1b-70b": {
+ "together-ai-41.1b-80b": {
"input_cost_per_token": 0.0000009,
"output_cost_per_token": 0.0000009,
"litellm_provider": "together_ai"
},
+ "together-ai-81.1b-110b": {
+ "input_cost_per_token": 0.0000018,
+ "output_cost_per_token": 0.0000018,
+ "litellm_provider": "together_ai"
+ },
"together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1": {
"input_cost_per_token": 0.0000006,
"output_cost_per_token": 0.0000006,
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
index 14787d256..fc813d761 100644
--- a/litellm/proxy/_experimental/out/404.html
+++ b/litellm/proxy/_experimental/out/404.html
@@ -1 +1 @@
-[minified Next.js HTML for the 404 page ("404: This page could not be found." / "LiteLLM Dashboard"); markup not preserved in this extract]
\ No newline at end of file
+[regenerated minified Next.js HTML for the 404 page; visible text unchanged]