+
+
+ If you have any questions, please send an email to {email_support_contact}
+
+ Best,
+ The LiteLLM team
+"""
diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py
index f4a581eb9..4d580f666 100644
--- a/litellm/integrations/langfuse.py
+++ b/litellm/integrations/langfuse.py
@@ -93,6 +93,7 @@ class LangFuseLogger:
)
litellm_params = kwargs.get("litellm_params", {})
+ litellm_call_id = kwargs.get("litellm_call_id", None)
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
@@ -161,6 +162,7 @@ class LangFuseLogger:
response_obj,
level,
print_verbose,
+ litellm_call_id,
)
elif response_obj is not None:
self._log_langfuse_v1(
@@ -255,6 +257,7 @@ class LangFuseLogger:
response_obj,
level,
print_verbose,
+ litellm_call_id,
) -> tuple:
import langfuse
@@ -318,7 +321,7 @@ class LangFuseLogger:
session_id = clean_metadata.pop("session_id", None)
trace_name = clean_metadata.pop("trace_name", None)
- trace_id = clean_metadata.pop("trace_id", None)
+ trace_id = clean_metadata.pop("trace_id", litellm_call_id)
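+ # fall back to litellm_call_id when the caller did not set a trace_id, so the
+ # langfuse trace stays linkable (e.g. from proxy logs / slack alerts)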
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
debug = clean_metadata.pop("debug_langfuse", None)
@@ -351,9 +354,13 @@ class LangFuseLogger:
# Special keys that are found in the function arguments and not the metadata
if "input" in update_trace_keys:
- trace_params["input"] = input if not mask_input else "redacted-by-litellm"
+ trace_params["input"] = (
+ input if not mask_input else "redacted-by-litellm"
+ )
if "output" in update_trace_keys:
- trace_params["output"] = output if not mask_output else "redacted-by-litellm"
+ trace_params["output"] = (
+ output if not mask_output else "redacted-by-litellm"
+ )
else: # don't overwrite an existing trace
trace_params = {
"id": trace_id,
@@ -375,7 +382,9 @@ class LangFuseLogger:
if level == "ERROR":
trace_params["status_message"] = output
else:
- trace_params["output"] = output if not mask_output else "redacted-by-litellm"
+ trace_params["output"] = (
+ output if not mask_output else "redacted-by-litellm"
+ )
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
if "metadata" in trace_params:
@@ -387,6 +396,8 @@ class LangFuseLogger:
cost = kwargs.get("response_cost", None)
print_verbose(f"trace: {cost}")
+ clean_metadata["litellm_response_cost"] = cost
+
if (
litellm._langfuse_default_tags is not None
and isinstance(litellm._langfuse_default_tags, list)
@@ -412,7 +423,6 @@ class LangFuseLogger:
if "cache_hit" in kwargs:
if kwargs["cache_hit"] is None:
kwargs["cache_hit"] = False
- tags.append(f"cache_hit:{kwargs['cache_hit']}")
clean_metadata["cache_hit"] = kwargs["cache_hit"]
if existing_trace_id is None:
trace_params.update({"tags": tags})
@@ -447,8 +457,13 @@ class LangFuseLogger:
}
generation_name = clean_metadata.pop("generation_name", None)
if generation_name is None:
- # just log `litellm-{call_type}` as the generation name
+ # if `generation_name` is None, use sensible default values
+ # If using litellm proxy user `key_alias` if not None
+ # If `key_alias` is None, just log `litellm-{call_type}` as the generation name
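+ # e.g. (illustrative) key_alias "prod-app-1" -> generation name "litellm:prod-app-1"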
+ _user_api_key_alias = clean_metadata.get("user_api_key_alias", None)
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
+ if _user_api_key_alias is not None:
+ generation_name = f"litellm:{_user_api_key_alias}"
if response_obj is not None and "system_fingerprint" in response_obj:
system_fingerprint = response_obj.get("system_fingerprint", None)
diff --git a/litellm/integrations/logfire_logger.py b/litellm/integrations/logfire_logger.py
new file mode 100644
index 000000000..e27d848fb
--- /dev/null
+++ b/litellm/integrations/logfire_logger.py
@@ -0,0 +1,178 @@
+#### What this does ####
+# On success + failure, log events to Logfire
+
+import dotenv, os
+
+dotenv.load_dotenv() # Loading env variables using dotenv
+import traceback
+import uuid
+from litellm._logging import print_verbose, verbose_logger
+
+from enum import Enum
+from typing import Any, Dict, NamedTuple
+from typing_extensions import LiteralString
+
+
+class SpanConfig(NamedTuple):
+ message_template: LiteralString
+ span_data: Dict[str, Any]
+
+
+class LogfireLevel(str, Enum):
+ INFO = "info"
+ ERROR = "error"
+
+
+class LogfireLogger:
+ # Class variables or attributes
+ def __init__(self):
+ try:
+ verbose_logger.debug(f"in init logfire logger")
+ import logfire
+
+ # only setting up logfire if we are sending to logfire
+ # in testing, we don't want to send to logfire
+ if logfire.DEFAULT_LOGFIRE_INSTANCE.config.send_to_logfire:
+ logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
+ except Exception as e:
+ print_verbose(f"Got exception on init logfire client {str(e)}")
+ raise e
+
+ def _get_span_config(self, payload) -> SpanConfig:
+ if (
+ payload["call_type"] == "completion"
+ or payload["call_type"] == "acompletion"
+ ):
+ return SpanConfig(
+ message_template="Chat Completion with {request_data[model]!r}",
+ span_data={"request_data": payload},
+ )
+ elif (
+ payload["call_type"] == "embedding" or payload["call_type"] == "aembedding"
+ ):
+ return SpanConfig(
+ message_template="Embedding Creation with {request_data[model]!r}",
+ span_data={"request_data": payload},
+ )
+ elif (
+ payload["call_type"] == "image_generation"
+ or payload["call_type"] == "aimage_generation"
+ ):
+ return SpanConfig(
+ message_template="Image Generation with {request_data[model]!r}",
+ span_data={"request_data": payload},
+ )
+ else:
+ return SpanConfig(
+ message_template="Litellm Call with {request_data[model]!r}",
+ span_data={"request_data": payload},
+ )
+
+ async def _async_log_event(
+ self,
+ kwargs,
+ response_obj,
+ start_time,
+ end_time,
+ print_verbose,
+ level: LogfireLevel,
+ ):
+ self.log_event(
+ kwargs=kwargs,
+ response_obj=response_obj,
+ start_time=start_time,
+ end_time=end_time,
+ print_verbose=print_verbose,
+ level=level,
+ )
+
+ def log_event(
+ self,
+ kwargs,
+ start_time,
+ end_time,
+ print_verbose,
+ level: LogfireLevel,
+ response_obj,
+ ):
+ try:
+ import logfire
+
+ verbose_logger.debug(
+ f"logfire Logging - Enters logging function for model {kwargs}"
+ )
+
+ if not response_obj:
+ response_obj = {}
+ litellm_params = kwargs.get("litellm_params", {})
+ metadata = (
+ litellm_params.get("metadata", {}) or {}
+ ) # if litellm_params['metadata'] == None
+ messages = kwargs.get("messages")
+ optional_params = kwargs.get("optional_params", {})
+ call_type = kwargs.get("call_type", "completion")
+ cache_hit = kwargs.get("cache_hit", False)
+ usage = response_obj.get("usage", {})
+ id = response_obj.get("id", str(uuid.uuid4()))
+ try:
+ response_time = (end_time - start_time).total_seconds()
+ except:
+ response_time = None
+
+ # Clean Metadata before logging - never log raw metadata
+ # the raw metadata can contain circular references which leads to infinite recursion
+ # we clean out all extra litellm metadata params before logging
+ clean_metadata = {}
+ if isinstance(metadata, dict):
+ for key, value in metadata.items():
+ # clean litellm metadata before logging
+ if key in [
+ "endpoint",
+ "caching_groups",
+ "previous_models",
+ ]:
+ continue
+ else:
+ clean_metadata[key] = value
+
+ # Build the initial payload
+ payload = {
+ "id": id,
+ "call_type": call_type,
+ "cache_hit": cache_hit,
+ "startTime": start_time,
+ "endTime": end_time,
+ "responseTime (seconds)": response_time,
+ "model": kwargs.get("model", ""),
+ "user": kwargs.get("user", ""),
+ "modelParameters": optional_params,
+ "spend": kwargs.get("response_cost", 0),
+ "messages": messages,
+ "response": response_obj,
+ "usage": usage,
+ "metadata": clean_metadata,
+ }
+ logfire_openai = logfire.with_settings(custom_scope_suffix="openai")
+ message_template, span_data = self._get_span_config(payload)
+ if level == LogfireLevel.INFO:
+ logfire_openai.info(
+ message_template,
+ **span_data,
+ )
+ elif level == LogfireLevel.ERROR:
+ logfire_openai.error(
+ message_template,
+ **span_data,
+ _exc_info=True,
+ )
+ print_verbose(f"\ndd Logger - Logging payload = {payload}")
+
+ print_verbose(
+ f"Logfire Layer Logging - final response object: {response_obj}"
+ )
+ except Exception as e:
+ traceback.print_exc()
+ verbose_logger.debug(
+ f"Logfire Layer Error - {str(e)}\n{traceback.format_exc()}"
+ )
+ pass
diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py
new file mode 100644
index 000000000..ac92d5ddd
--- /dev/null
+++ b/litellm/integrations/opentelemetry.py
@@ -0,0 +1,197 @@
+import os
+from typing import Optional, Union
+from dataclasses import dataclass
+
+from litellm.integrations.custom_logger import CustomLogger
+from litellm._logging import verbose_logger
+
+LITELLM_TRACER_NAME = "litellm"
+LITELLM_RESOURCE = {"service.name": "litellm"}
+
+
+@dataclass
+class OpenTelemetryConfig:
+ from opentelemetry.sdk.trace.export import SpanExporter
+
+ exporter: Union[str, SpanExporter] = "console"  # Union form keeps Python <3.10 compatibility
+ endpoint: Optional[str] = None
+ headers: Optional[str] = None
+
+ @classmethod
+ def from_env(cls):
+ """
+ OTEL_HEADERS=x-honeycomb-team=B85YgLm9****
+ OTEL_EXPORTER="otlp_http"
+ OTEL_ENDPOINT="https://api.honeycomb.io/v1/traces"
+
+ OTEL_HEADERS gets sent as headers = {"x-honeycomb-team": "B85YgLm96******"}
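+
+ Illustrative result (values are examples): the env vars above yield
+ OpenTelemetryConfig(exporter="otlp_http",
+ endpoint="https://api.honeycomb.io/v1/traces",
+ headers="x-honeycomb-team=B85YgLm9****")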
+ """
+ return cls(
+ exporter=os.getenv("OTEL_EXPORTER", "console"),
+ endpoint=os.getenv("OTEL_ENDPOINT"),
+ headers=os.getenv(
+ "OTEL_HEADERS"
+ ), # example: OTEL_HEADERS=x-honeycomb-team=B85YgLm96VGdFisfJVme1H
+ )
+
+
+class OpenTelemetry(CustomLogger):
+ def __init__(self, config=OpenTelemetryConfig.from_env()):
+ from opentelemetry import trace
+ from opentelemetry.sdk.resources import Resource
+ from opentelemetry.sdk.trace import TracerProvider
+
+ self.config = config
+ self.OTEL_EXPORTER = self.config.exporter
+ self.OTEL_ENDPOINT = self.config.endpoint
+ self.OTEL_HEADERS = self.config.headers
+ provider = TracerProvider(resource=Resource(attributes=LITELLM_RESOURCE))
+ provider.add_span_processor(self._get_span_processor())
+
+ trace.set_tracer_provider(provider)
+ self.tracer = trace.get_tracer(LITELLM_TRACER_NAME)
+
+ if bool(os.getenv("DEBUG_OTEL", False)) is True:
+ # Set up logging
+ import logging
+
+ logging.basicConfig(level=logging.DEBUG)
+ logger = logging.getLogger(__name__)
+
+ # Enable OpenTelemetry logging
+ otel_exporter_logger = logging.getLogger("opentelemetry.sdk.trace.export")
+ otel_exporter_logger.setLevel(logging.DEBUG)
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_success(kwargs, response_obj, start_time, end_time)
+
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_success(kwargs, response_obj, start_time, end_time)
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ self._handle_failure(kwargs, response_obj, start_time, end_time)
+
+ def _handle_success(self, kwargs, response_obj, start_time, end_time):
+ from opentelemetry.trace import Status, StatusCode
+
+ verbose_logger.debug(
+ "OpenTelemetry Logger: Logging kwargs: %s, OTEL config settings=%s",
+ kwargs,
+ self.config,
+ )
+
+ span = self.tracer.start_span(
+ name=self._get_span_name(kwargs),
+ start_time=self._to_ns(start_time),
+ context=self._get_span_context(kwargs),
+ )
+ span.set_status(Status(StatusCode.OK))
+ self.set_attributes(span, kwargs, response_obj)
+ span.end(end_time=self._to_ns(end_time))
+
+ def _handle_failure(self, kwargs, response_obj, start_time, end_time):
+ from opentelemetry.trace import Status, StatusCode
+
+ span = self.tracer.start_span(
+ name=self._get_span_name(kwargs),
+ start_time=self._to_ns(start_time),
+ context=self._get_span_context(kwargs),
+ )
+ span.set_status(Status(StatusCode.ERROR))
+ self.set_attributes(span, kwargs, response_obj)
+ span.end(end_time=self._to_ns(end_time))
+
+ def set_attributes(self, span, kwargs, response_obj):
+ for key in ["model", "api_base", "api_version"]:
+ if key in kwargs:
+ span.set_attribute(key, kwargs[key])
+
+ def _to_ns(self, dt):
+ return int(dt.timestamp() * 1e9)
+
+ def _get_span_name(self, kwargs):
+ return f"litellm-{kwargs.get('call_type', 'completion')}"
+
+ def _get_span_context(self, kwargs):
+ from opentelemetry.trace.propagation.tracecontext import (
+ TraceContextTextMapPropagator,
+ )
+
+ litellm_params = kwargs.get("litellm_params", {}) or {}
+ proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
+ headers = proxy_server_request.get("headers", {}) or {}
+ traceparent = headers.get("traceparent", None)
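+ # e.g. (illustrative) traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
+ # (W3C Trace Context), so litellm spans attach to the caller's existing trace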
+
+ if traceparent is None:
+ return None
+ else:
+ carrier = {"traceparent": traceparent}
+ return TraceContextTextMapPropagator().extract(carrier=carrier)
+
+ def _get_span_processor(self):
+ from opentelemetry.sdk.trace.export import (
+ SpanExporter,
+ SimpleSpanProcessor,
+ BatchSpanProcessor,
+ ConsoleSpanExporter,
+ )
+ from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
+ OTLPSpanExporter as OTLPSpanExporterHTTP,
+ )
+ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
+ OTLPSpanExporter as OTLPSpanExporterGRPC,
+ )
+
+ verbose_logger.debug(
+ "OpenTelemetry Logger, initializing span processor \nself.OTEL_EXPORTER: %s\nself.OTEL_ENDPOINT: %s\nself.OTEL_HEADERS: %s",
+ self.OTEL_EXPORTER,
+ self.OTEL_ENDPOINT,
+ self.OTEL_HEADERS,
+ )
+ _split_otel_headers = {}
+ if self.OTEL_HEADERS is not None and isinstance(self.OTEL_HEADERS, str):
+ _split_otel_headers = self.OTEL_HEADERS.split("=", 1)  # split only on the first "=", header values may contain "="
+ _split_otel_headers = {_split_otel_headers[0]: _split_otel_headers[1]}
+
+ if isinstance(self.OTEL_EXPORTER, SpanExporter):
+ verbose_logger.debug(
+ "OpenTelemetry: intiializing SpanExporter. Value of OTEL_EXPORTER: %s",
+ self.OTEL_EXPORTER,
+ )
+ return SimpleSpanProcessor(self.OTEL_EXPORTER)
+
+ if self.OTEL_EXPORTER == "console":
+ verbose_logger.debug(
+ "OpenTelemetry: intiializing console exporter. Value of OTEL_EXPORTER: %s",
+ self.OTEL_EXPORTER,
+ )
+ return BatchSpanProcessor(ConsoleSpanExporter())
+ elif self.OTEL_EXPORTER == "otlp_http":
+ verbose_logger.debug(
+ "OpenTelemetry: intiializing http exporter. Value of OTEL_EXPORTER: %s",
+ self.OTEL_EXPORTER,
+ )
+ return BatchSpanProcessor(
+ OTLPSpanExporterHTTP(
+ endpoint=self.OTEL_ENDPOINT, headers=_split_otel_headers
+ )
+ )
+ elif self.OTEL_EXPORTER == "otlp_grpc":
+ verbose_logger.debug(
+ "OpenTelemetry: intiializing grpc exporter. Value of OTEL_EXPORTER: %s",
+ self.OTEL_EXPORTER,
+ )
+ return BatchSpanProcessor(
+ OTLPSpanExporterGRPC(
+ endpoint=self.OTEL_ENDPOINT, headers=_split_otel_headers
+ )
+ )
+ else:
+ verbose_logger.debug(
+ "OpenTelemetry: intiializing console exporter. Value of OTEL_EXPORTER: %s",
+ self.OTEL_EXPORTER,
+ )
+ return BatchSpanProcessor(ConsoleSpanExporter())
diff --git a/litellm/integrations/slack_alerting.py b/litellm/integrations/slack_alerting.py
index 015278c55..5ed92af0a 100644
--- a/litellm/integrations/slack_alerting.py
+++ b/litellm/integrations/slack_alerting.py
@@ -1,20 +1,48 @@
#### What this does ####
# Class for sending Slack Alerts #
-import dotenv, os
-from litellm.proxy._types import UserAPIKeyAuth
+import dotenv, os, traceback
+from litellm.proxy._types import UserAPIKeyAuth, CallInfo, AlertType
from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm, threading
-from typing import List, Literal, Any, Union, Optional, Dict
+from typing import List, Literal, Any, Union, Optional, Dict, Set
from litellm.caching import DualCache
-import asyncio
+import asyncio, time
import aiohttp
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
import datetime
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from enum import Enum
from datetime import datetime as dt, timedelta, timezone
from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import WebhookEvent
import random
+from typing import TypedDict
+from openai import APIError
+from .email_templates.templates import *
+
+import litellm.types
+from litellm.types.router import LiteLLM_Params
+
+
+class BaseOutageModel(TypedDict):
+ alerts: List[int]
+ minor_alert_sent: bool
+ major_alert_sent: bool
+ last_updated_at: float
+
+
+class OutageModel(BaseOutageModel):
+ model_id: str
+
+
+class ProviderRegionOutageModel(BaseOutageModel):
+ provider_region_id: str
+ deployment_ids: Set[str]
+
+
+# We use this for the email header. If you change it, send a test email and verify the header renders correctly.
+LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
+LITELLM_SUPPORT_CONTACT = "support@berri.ai"
class LiteLLMBase(BaseModel):
@@ -30,12 +58,55 @@ class LiteLLMBase(BaseModel):
return self.dict()
+class SlackAlertingArgsEnum(Enum):
+ daily_report_frequency: int = 12 * 60 * 60
+ report_check_interval: int = 5 * 60
+ budget_alert_ttl: int = 24 * 60 * 60
+ outage_alert_ttl: int = 1 * 60
+ region_outage_alert_ttl: int = 1 * 60
+ minor_outage_alert_threshold: int = 1 * 5
+ major_outage_alert_threshold: int = 1 * 10
+ max_outage_alert_list_size: int = 1 * 10
+
+
class SlackAlertingArgs(LiteLLMBase):
- default_daily_report_frequency: int = 12 * 60 * 60 # 12 hours
- daily_report_frequency: int = int(
- os.getenv("SLACK_DAILY_REPORT_FREQUENCY", default_daily_report_frequency)
+ daily_report_frequency: int = Field(
+ default=int(
+ os.getenv(
+ "SLACK_DAILY_REPORT_FREQUENCY",
+ SlackAlertingArgsEnum.daily_report_frequency.value,
+ )
+ ),
+ description="Frequency of receiving deployment latency/failure reports. Default is 12hours. Value is in seconds.",
)
- report_check_interval: int = 5 * 60 # 5 minutes
+ report_check_interval: int = Field(
+ default=SlackAlertingArgsEnum.report_check_interval.value,
+ description="Frequency of checking cache if report should be sent. Background process. Default is once per hour. Value is in seconds.",
+ ) # 5 minutes
+ budget_alert_ttl: int = Field(
+ default=SlackAlertingArgsEnum.budget_alert_ttl.value,
+ description="Cache ttl for budgets alerts. Prevents spamming same alert, each time budget is crossed. Value is in seconds.",
+ ) # 24 hours
+ outage_alert_ttl: int = Field(
+ default=SlackAlertingArgsEnum.outage_alert_ttl.value,
+ description="Cache ttl for model outage alerts. Sets time-window for errors. Default is 1 minute. Value is in seconds.",
+ ) # 1 minute ttl
+ region_outage_alert_ttl: int = Field(
+ default=SlackAlertingArgsEnum.region_outage_alert_ttl.value,
+ description="Cache ttl for provider-region based outage alerts. Alert sent if 2+ models in same region report errors. Sets time-window for errors. Default is 1 minute. Value is in seconds.",
+ ) # 1 minute ttl
+ minor_outage_alert_threshold: int = Field(
+ default=SlackAlertingArgsEnum.minor_outage_alert_threshold.value,
+ description="The number of errors that count as a model/region minor outage. ('400' error code is not counted).",
+ )
+ major_outage_alert_threshold: int = Field(
+ default=SlackAlertingArgsEnum.major_outage_alert_threshold.value,
+ description="The number of errors that countas a model/region major outage. ('400' error code is not counted).",
+ )
+ max_outage_alert_list_size: int = Field(
+ default=SlackAlertingArgsEnum.max_outage_alert_list_size.value,
+ description="Maximum number of errors to store in cache. For a given model/region. Prevents memory leaks.",
+ ) # prevent memory leak
class DeploymentMetrics(LiteLLMBase):
@@ -79,19 +150,7 @@ class SlackAlerting(CustomLogger):
internal_usage_cache: Optional[DualCache] = None,
alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds)
alerting: Optional[List] = [],
- alert_types: List[
- Literal[
- "llm_exceptions",
- "llm_too_slow",
- "llm_requests_hanging",
- "budget_alerts",
- "db_exceptions",
- "daily_reports",
- "spend_reports",
- "cooldown_deployment",
- "new_model_added",
- ]
- ] = [
+ alert_types: List[AlertType] = [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
@@ -101,6 +160,7 @@ class SlackAlerting(CustomLogger):
"spend_reports",
"cooldown_deployment",
"new_model_added",
+ "outage_alerts",
],
alert_to_webhook_url: Optional[
Dict
@@ -117,6 +177,7 @@ class SlackAlerting(CustomLogger):
self.is_running = False
self.alerting_args = SlackAlertingArgs(**alerting_args)
self.default_webhook_url = default_webhook_url
+ self.llm_router: Optional[litellm.Router] = None
def update_values(
self,
@@ -125,6 +186,7 @@ class SlackAlerting(CustomLogger):
alert_types: Optional[List] = None,
alert_to_webhook_url: Optional[Dict] = None,
alerting_args: Optional[Dict] = None,
+ llm_router: Optional[litellm.Router] = None,
):
if alerting is not None:
self.alerting = alerting
@@ -140,6 +202,8 @@ class SlackAlerting(CustomLogger):
self.alert_to_webhook_url = alert_to_webhook_url
else:
self.alert_to_webhook_url.update(alert_to_webhook_url)
+ if llm_router is not None:
+ self.llm_router = llm_router
async def deployment_in_cooldown(self):
pass
@@ -164,13 +228,28 @@ class SlackAlerting(CustomLogger):
) -> Optional[str]:
"""
Returns langfuse trace url
+
+ - check:
+ -> existing_trace_id
+ -> trace_id
+ -> litellm_call_id
"""
# do nothing for now
- if (
- request_data is not None
- and request_data.get("metadata", {}).get("trace_id", None) is not None
- ):
- trace_id = request_data["metadata"]["trace_id"]
+ if request_data is not None:
+ trace_id = None
+ if (
+ request_data.get("metadata", {}).get("existing_trace_id", None)
+ is not None
+ ):
+ trace_id = request_data["metadata"]["existing_trace_id"]
+ elif request_data.get("metadata", {}).get("trace_id", None) is not None:
+ trace_id = request_data["metadata"]["trace_id"]
+ elif request_data.get("litellm_logging_obj", None) is not None and hasattr(
+ request_data["litellm_logging_obj"], "model_call_details"
+ ):
+ trace_id = request_data["litellm_logging_obj"].model_call_details[
+ "litellm_call_id"
+ ]
if litellm.utils.langFuseLogger is not None:
base_url = litellm.utils.langFuseLogger.Langfuse.base_url
return f"{base_url}/trace/{trace_id}"
@@ -353,6 +432,9 @@ class SlackAlerting(CustomLogger):
keys=combined_metrics_keys
) # [1, 2, None, ..]
+ if combined_metrics_values is None:
+ return False
+
all_none = True
for val in combined_metrics_values:
if val is not None and val > 0:
@@ -404,7 +486,7 @@ class SlackAlerting(CustomLogger):
]
# format alert -> return the litellm model name + api base
- message = f"\n\nHere are today's key metrics 📈: \n\n"
+ message = f"\n\nTime: `{time.time()}`s\nHere are today's key metrics 📈: \n\n"
message += "\n\n*❗️ Top Deployments with Most Failed Requests:*\n\n"
if not top_5_failed:
@@ -455,6 +537,8 @@ class SlackAlerting(CustomLogger):
cache_list=combined_metrics_cache_keys
)
+ message += f"\n\nNext Run is at: `{time.time() + self.alerting_args.daily_report_frequency}`s"
+
# send alert
await self.send_alert(message=message, level="Low", alert_type="daily_reports")
@@ -555,127 +639,468 @@ class SlackAlerting(CustomLogger):
alert_type="llm_requests_hanging",
)
+ async def failed_tracking_alert(self, error_message: str):
+ """Raise alert when tracking failed for specific model"""
+ _cache: DualCache = self.internal_usage_cache
+ message = "Failed Tracking Cost for" + error_message
+ _cache_key = "budget_alerts:failed_tracking:{}".format(message)
+ result = await _cache.async_get_cache(key=_cache_key)
+ if result is None:
+ await self.send_alert(
+ message=message, level="High", alert_type="budget_alerts"
+ )
+ await _cache.async_set_cache(
+ key=_cache_key,
+ value="SENT",
+ ttl=self.alerting_args.budget_alert_ttl,
+ )
+
async def budget_alerts(
self,
type: Literal[
"token_budget",
"user_budget",
- "user_and_proxy_budget",
- "failed_budgets",
- "failed_tracking",
+ "team_budget",
+ "proxy_budget",
"projected_limit_exceeded",
],
- user_max_budget: float,
- user_current_spend: float,
- user_info=None,
- error_message="",
+ user_info: CallInfo,
):
+ ## PREVENTIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
+ # - Alert once within 24hr period
+ # - Cache this information
+ # - Don't re-alert, if alert already sent
+ _cache: DualCache = self.internal_usage_cache
+
if self.alerting is None or self.alert_types is None:
# do nothing if alerting is not switched on
return
if "budget_alerts" not in self.alert_types:
return
_id: str = "default_id" # used for caching
- if type == "user_and_proxy_budget":
- user_info = dict(user_info)
- user_id = user_info["user_id"]
- _id = user_id
- max_budget = user_info["max_budget"]
- spend = user_info["spend"]
- user_email = user_info["user_email"]
- user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
+ user_info_json = user_info.model_dump(exclude_none=True)
+ user_info_str = ""
+ for k, v in user_info_json.items():
+ user_info_str += "\n{}: {}\n".format(k, v)
+
+ event: Optional[
+ Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
+ ] = None
+ event_group: Optional[
+ Literal["internal_user", "team", "key", "proxy", "customer"]
+ ] = None
+ event_message: str = ""
+ webhook_event: Optional[WebhookEvent] = None
+ if type == "proxy_budget":
+ event_group = "proxy"
+ event_message += "Proxy Budget: "
+ elif type == "user_budget":
+ event_group = "internal_user"
+ event_message += "User Budget: "
+ _id = user_info.user_id or _id
+ elif type == "team_budget":
+ event_group = "team"
+ event_message += "Team Budget: "
+ _id = user_info.team_id or _id
elif type == "token_budget":
- token_info = dict(user_info)
- token = token_info["token"]
- _id = token
- spend = token_info["spend"]
- max_budget = token_info["max_budget"]
- user_id = token_info["user_id"]
- user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
- elif type == "failed_tracking":
- user_id = str(user_info)
- _id = user_id
- user_info = f"\nUser ID: {user_id}\n Error {error_message}"
- message = "Failed Tracking Cost for" + user_info
- await self.send_alert(
- message=message, level="High", alert_type="budget_alerts"
- )
- return
- elif type == "projected_limit_exceeded" and user_info is not None:
- """
- Input variables:
- user_info = {
- "key_alias": key_alias,
- "projected_spend": projected_spend,
- "projected_exceeded_date": projected_exceeded_date,
- }
- user_max_budget=soft_limit,
- user_current_spend=new_spend
- """
- message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` {user_info["key_alias"]} \n`Expected Day of Error`: {user_info["projected_exceeded_date"]} \n`Current Spend`: {user_current_spend} \n`Projected Spend at end of month`: {user_info["projected_spend"]} \n`Soft Limit`: {user_max_budget}"""
- await self.send_alert(
- message=message, level="High", alert_type="budget_alerts"
- )
- return
- else:
- user_info = str(user_info)
+ event_group = "key"
+ event_message += "Key Budget: "
+ _id = user_info.token
+ elif type == "projected_limit_exceeded":
+ event_group = "key"
+ event_message += "Key Budget: Projected Limit Exceeded"
+ event = "projected_limit_exceeded"
+ _id = user_info.token
# percent of max_budget left to spend
- if user_max_budget > 0:
- percent_left = (user_max_budget - user_current_spend) / user_max_budget
+ if user_info.max_budget is None:
+ return
+
+ if user_info.max_budget > 0:
+ percent_left = (
+ user_info.max_budget - user_info.spend
+ ) / user_info.max_budget
else:
percent_left = 0
- verbose_proxy_logger.debug(
- f"Budget Alerts: Percent left: {percent_left} for {user_info}"
- )
-
- ## PREVENTITIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
- # - Alert once within 28d period
- # - Cache this information
- # - Don't re-alert, if alert already sent
- _cache: DualCache = self.internal_usage_cache
# check if crossed budget
- if user_current_spend >= user_max_budget:
- verbose_proxy_logger.debug("Budget Crossed for %s", user_info)
- message = "Budget Crossed for" + user_info
- result = await _cache.async_get_cache(key=message)
- if result is None:
- await self.send_alert(
- message=message, level="High", alert_type="budget_alerts"
- )
- await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
- return
+ if user_info.spend >= user_info.max_budget:
+ event = "budget_crossed"
+ event_message += f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
+ elif percent_left <= 0.05:
+ event = "threshold_crossed"
+ event_message += "5% Threshold Crossed "
+ elif percent_left <= 0.15:
+ event = "threshold_crossed"
+ event_message += "15% Threshold Crossed"
- # check if 5% of max budget is left
- if percent_left <= 0.05:
- message = "5% budget left for" + user_info
- cache_key = "alerting:{}".format(_id)
- result = await _cache.async_get_cache(key=cache_key)
+ if event is not None and event_group is not None:
+ _cache_key = "budget_alerts:{}:{}".format(event, _id)
+ result = await _cache.async_get_cache(key=_cache_key)
if result is None:
+ webhook_event = WebhookEvent(
+ event=event,
+ event_group=event_group,
+ event_message=event_message,
+ **user_info_json,
+ )
await self.send_alert(
- message=message, level="Medium", alert_type="budget_alerts"
+ message=event_message + "\n\n" + user_info_str,
+ level="High",
+ alert_type="budget_alerts",
+ user_info=webhook_event,
+ )
+ await _cache.async_set_cache(
+ key=_cache_key,
+ value="SENT",
+ ttl=self.alerting_args.budget_alert_ttl,
)
- await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)
-
return
-
- # check if 15% of max budget is left
- if percent_left <= 0.15:
- message = "15% budget left for" + user_info
- result = await _cache.async_get_cache(key=message)
- if result is None:
- await self.send_alert(
- message=message, level="Low", alert_type="budget_alerts"
- )
- await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
- return
-
return
- async def model_added_alert(self, model_name: str, litellm_model_name: str):
- model_info = litellm.model_cost.get(litellm_model_name, {})
+ async def customer_spend_alert(
+ self,
+ token: Optional[str],
+ key_alias: Optional[str],
+ end_user_id: Optional[str],
+ response_cost: Optional[float],
+ max_budget: Optional[float],
+ ):
+ if end_user_id is not None and token is not None and response_cost is not None:
+ # log customer spend
+ event = WebhookEvent(
+ spend=response_cost,
+ max_budget=max_budget,
+ token=token,
+ customer_id=end_user_id,
+ user_id=None,
+ team_id=None,
+ user_email=None,
+ key_alias=key_alias,
+ projected_exceeded_date=None,
+ projected_spend=None,
+ event="spend_tracked",
+ event_group="customer",
+ event_message="Customer spend tracked. Customer={}, spend={}".format(
+ end_user_id, response_cost
+ ),
+ )
+
+ await self.send_webhook_alert(webhook_event=event)
+
+ def _count_outage_alerts(self, alerts: List[int]) -> str:
+ """
+ Parameters:
+ - alerts: List[int] -> list of error codes (either 408 or 500+)
+
+ Returns:
+ - str -> formatted string. This is an alert message, giving a human-friendly description of the errors.
+ """
+ error_breakdown = {"Timeout Errors": 0, "API Errors": 0, "Unknown Errors": 0}
+ for alert in alerts:
+ if alert == 408:
+ error_breakdown["Timeout Errors"] += 1
+ elif alert >= 500:
+ error_breakdown["API Errors"] += 1
+ else:
+ error_breakdown["Unknown Errors"] += 1
+
+ error_msg = ""
+ for key, value in error_breakdown.items():
+ if value > 0:
+ error_msg += "\n{}: {}\n".format(key, value)
+
+ return error_msg
+
+ def _outage_alert_msg_factory(
+ self,
+ alert_type: Literal["Major", "Minor"],
+ key: Literal["Model", "Region"],
+ key_val: str,
+ provider: str,
+ api_base: Optional[str],
+ outage_value: BaseOutageModel,
+ ) -> str:
+ """Format an alert message for slack"""
+ headers = {f"{key} Name": key_val, "Provider": provider}
+ if api_base is not None:
+ headers["API Base"] = api_base # type: ignore
+
+ headers_str = "\n"
+ for k, v in headers.items():
+ headers_str += f"*{k}:* `{v}`\n"
+ return f"""\n\n
+*⚠️ {alert_type} Service Outage*
+
+{headers_str}
+
+*Errors:*
+{self._count_outage_alerts(alerts=outage_value["alerts"])}
+
+*Last Check:* `{round(time.time() - outage_value["last_updated_at"], 4)}s ago`\n\n
+"""
+
+ async def region_outage_alerts(
+ self,
+ exception: APIError,
+ deployment_id: str,
+ ) -> None:
+ """
+ Send a slack alert if a specific provider region is having an outage.
+
+ Tracks 408 (Timeout) and >=500 error codes.
+ """
+ ## CREATE (PROVIDER+REGION) ID ##
+ if self.llm_router is None:
+ return
+
+ deployment = self.llm_router.get_deployment(model_id=deployment_id)
+
+ if deployment is None:
+ return
+
+ model = deployment.litellm_params.model
+ ### GET PROVIDER ###
+ provider = deployment.litellm_params.custom_llm_provider
+ if provider is None:
+ model, provider, _, _ = litellm.get_llm_provider(model=model)
+
+ ### GET REGION ###
+ region_name = deployment.litellm_params.region_name
+ if region_name is None:
+ region_name = litellm.utils._get_model_region(
+ custom_llm_provider=provider, litellm_params=deployment.litellm_params
+ )
+
+ if region_name is None:
+ return
+
+ ### UNIQUE CACHE KEY ###
+ cache_key = provider + region_name
+
+ outage_value: Optional[ProviderRegionOutageModel] = (
+ await self.internal_usage_cache.async_get_cache(key=cache_key)
+ )
+
+ if (
+ getattr(exception, "status_code", None) is None
+ or (
+ exception.status_code != 408 # type: ignore
+ and exception.status_code < 500 # type: ignore
+ )
+ or self.llm_router is None
+ ):
+ return
+
+ if outage_value is None:
+ _deployment_set = set()
+ _deployment_set.add(deployment_id)
+ outage_value = ProviderRegionOutageModel(
+ provider_region_id=cache_key,
+ alerts=[exception.status_code], # type: ignore
+ minor_alert_sent=False,
+ major_alert_sent=False,
+ last_updated_at=time.time(),
+ deployment_ids=_deployment_set,
+ )
+
+ ## add to cache ##
+ await self.internal_usage_cache.async_set_cache(
+ key=cache_key,
+ value=outage_value,
+ ttl=self.alerting_args.region_outage_alert_ttl,
+ )
+ return
+
+ if len(outage_value["alerts"]) < self.alerting_args.max_outage_alert_list_size:
+ outage_value["alerts"].append(exception.status_code) # type: ignore
+ else: # prevent memory leaks
+ pass
+ _deployment_set = outage_value["deployment_ids"]
+ _deployment_set.add(deployment_id)
+ outage_value["deployment_ids"] = _deployment_set
+ outage_value["last_updated_at"] = time.time()
+
+ ## MINOR OUTAGE ALERT SENT ##
+ if (
+ outage_value["minor_alert_sent"] == False
+ and len(outage_value["alerts"])
+ >= self.alerting_args.minor_outage_alert_threshold
+ and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
+ ):
+ msg = self._outage_alert_msg_factory(
+ alert_type="Minor",
+ key="Region",
+ key_val=region_name,
+ api_base=None,
+ outage_value=outage_value,
+ provider=provider,
+ )
+ # send minor alert
+ await self.send_alert(
+ message=msg, level="Medium", alert_type="outage_alerts"
+ )
+ # set to true
+ outage_value["minor_alert_sent"] = True
+
+ ## MAJOR OUTAGE ALERT SENT ##
+ elif (
+ outage_value["major_alert_sent"] == False
+ and len(outage_value["alerts"])
+ >= self.alerting_args.major_outage_alert_threshold
+ and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
+ ):
+ msg = self._outage_alert_msg_factory(
+ alert_type="Major",
+ key="Region",
+ key_val=region_name,
+ api_base=None,
+ outage_value=outage_value,
+ provider=provider,
+ )
+
+ # send major alert
+ await self.send_alert(message=msg, level="High", alert_type="outage_alerts")
+ # set to true
+ outage_value["major_alert_sent"] = True
+
+ ## update cache ##
+ await self.internal_usage_cache.async_set_cache(
+ key=cache_key, value=outage_value
+ )
+
+ async def outage_alerts(
+ self,
+ exception: APIError,
+ deployment_id: str,
+ ) -> None:
+ """
+ Send slack alert if model is badly configured / having an outage (408, 401, 429, >=500).
+
+ key = model_id
+
+ value = {
+ - model_id
+ - threshold
+ - alerts []
+ }
+
+ ttl = 1hr
+ max_alerts_size = 10
+ """
+ try:
+ outage_value: Optional[OutageModel] = await self.internal_usage_cache.async_get_cache(key=deployment_id) # type: ignore
+ if (
+ getattr(exception, "status_code", None) is None
+ or (
+ exception.status_code != 408 # type: ignore
+ and exception.status_code < 500 # type: ignore
+ )
+ or self.llm_router is None
+ ):
+ return
+
+ ### EXTRACT MODEL DETAILS ###
+ deployment = self.llm_router.get_deployment(model_id=deployment_id)
+ if deployment is None:
+ return
+
+ model = deployment.litellm_params.model
+ provider = deployment.litellm_params.custom_llm_provider
+ if provider is None:
+ try:
+ model, provider, _, _ = litellm.get_llm_provider(model=model)
+ except Exception as e:
+ provider = ""
+ api_base = litellm.get_api_base(
+ model=model, optional_params=deployment.litellm_params
+ )
+
+ if outage_value is None:
+ outage_value = OutageModel(
+ model_id=deployment_id,
+ alerts=[exception.status_code], # type: ignore
+ minor_alert_sent=False,
+ major_alert_sent=False,
+ last_updated_at=time.time(),
+ )
+
+ ## add to cache ##
+ await self.internal_usage_cache.async_set_cache(
+ key=deployment_id,
+ value=outage_value,
+ ttl=self.alerting_args.outage_alert_ttl,
+ )
+ return
+
+ if (
+ len(outage_value["alerts"])
+ < self.alerting_args.max_outage_alert_list_size
+ ):
+ outage_value["alerts"].append(exception.status_code) # type: ignore
+ else: # prevent memory leaks
+ pass
+
+ outage_value["last_updated_at"] = time.time()
+
+ ## MINOR OUTAGE ALERT SENT ##
+ if (
+ outage_value["minor_alert_sent"] == False
+ and len(outage_value["alerts"])
+ >= self.alerting_args.minor_outage_alert_threshold
+ ):
+ msg = self._outage_alert_msg_factory(
+ alert_type="Minor",
+ key="Model",
+ key_val=model,
+ api_base=api_base,
+ outage_value=outage_value,
+ provider=provider,
+ )
+ # send minor alert
+ await self.send_alert(
+ message=msg, level="Medium", alert_type="outage_alerts"
+ )
+ # set to true
+ outage_value["minor_alert_sent"] = True
+ elif (
+ outage_value["major_alert_sent"] == False
+ and len(outage_value["alerts"])
+ >= self.alerting_args.major_outage_alert_threshold
+ ):
+ msg = self._outage_alert_msg_factory(
+ alert_type="Major",
+ key="Model",
+ key_val=model,
+ api_base=api_base,
+ outage_value=outage_value,
+ provider=provider,
+ )
+ # send major alert
+ await self.send_alert(
+ message=msg, level="High", alert_type="outage_alerts"
+ )
+ # set to true
+ outage_value["major_alert_sent"] = True
+
+ ## update cache ##
+ await self.internal_usage_cache.async_set_cache(
+ key=deployment_id, value=outage_value
+ )
+ except Exception as e:
+ pass
+
+ async def model_added_alert(
+ self, model_name: str, litellm_model_name: str, passed_model_info: Any
+ ):
+ base_model_from_user = getattr(passed_model_info, "base_model", None)
+ model_info = {}
+ base_model = ""
+ if base_model_from_user is not None:
+ model_info = litellm.model_cost.get(base_model_from_user, {})
+ base_model = f"Base Model: `{base_model_from_user}`\n"
+ else:
+ model_info = litellm.model_cost.get(litellm_model_name, {})
model_info_str = ""
for k, v in model_info.items():
if k == "input_cost_per_token" or k == "output_cost_per_token":
@@ -687,6 +1112,7 @@ class SlackAlerting(CustomLogger):
message = f"""
*🚅 New Model Added*
Model Name: `{model_name}`
+{base_model}
Usage OpenAI Python SDK:
```
@@ -713,29 +1139,229 @@ Model Info:
```
"""
- await self.send_alert(
+ alert_val = self.send_alert(
message=message, level="Low", alert_type="new_model_added"
)
- pass
+
+ if alert_val is not None and asyncio.iscoroutine(alert_val):
+ await alert_val
async def model_removed_alert(self, model_name: str):
pass
+ async def send_webhook_alert(self, webhook_event: WebhookEvent) -> bool:
+ """
+ Sends structured alert to webhook, if set.
+
+ Currently only implemented for budget alerts
+
+ Returns -> True if sent, False if not.
+ """
+
+ webhook_url = os.getenv("WEBHOOK_URL", None)
+ if webhook_url is None:
+ raise Exception("Missing webhook_url from environment")
+
+ payload = webhook_event.model_dump_json()
+ headers = {"Content-type": "application/json"}
+
+ response = await self.async_http_handler.post(
+ url=webhook_url,
+ headers=headers,
+ data=payload,
+ )
+ if response.status_code == 200:
+ return True
+ else:
+ print("Error sending webhook alert. Error=", response.text) # noqa
+
+ return False
+
+ async def _check_if_using_premium_email_feature(
+ self,
+ premium_user: bool,
+ email_logo_url: Optional[str] = None,
+ email_support_contact: Optional[str] = None,
+ ):
+ from litellm.proxy.proxy_server import premium_user
+ from litellm.proxy.proxy_server import CommonProxyErrors
+
+ if premium_user is not True:
+ if email_logo_url is not None or email_support_contact is not None:
+ raise ValueError(
+ f"Trying to Customize Email Alerting\n {CommonProxyErrors.not_premium_user.value}"
+ )
+ return
+
+ async def send_key_created_or_user_invited_email(
+ self, webhook_event: WebhookEvent
+ ) -> bool:
+ try:
+ from litellm.proxy.utils import send_email
+
+ if self.alerting is None or "email" not in self.alerting:
+ # do nothing if user does not want email alerts
+ return False
+ from litellm.proxy.proxy_server import premium_user, prisma_client
+
+ email_logo_url = os.getenv("SMTP_SENDER_LOGO", None)
+ email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
+ await self._check_if_using_premium_email_feature(
+ premium_user, email_logo_url, email_support_contact
+ )
+ if email_logo_url is None:
+ email_logo_url = LITELLM_LOGO_URL
+ if email_support_contact is None:
+ email_support_contact = LITELLM_SUPPORT_CONTACT
+
+ event_name = webhook_event.event_message
+ recipient_email = webhook_event.user_email
+ recipient_user_id = webhook_event.user_id
+ if (
+ recipient_email is None
+ and recipient_user_id is not None
+ and prisma_client is not None
+ ):
+ user_row = await prisma_client.db.litellm_usertable.find_unique(
+ where={"user_id": recipient_user_id}
+ )
+
+ if user_row is not None:
+ recipient_email = user_row.user_email
+
+ key_name = webhook_event.key_alias
+ key_token = webhook_event.token
+ key_budget = webhook_event.max_budget
+ base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")
+
+ email_html_content = "Alert from LiteLLM Server"
+ if recipient_email is None:
+ verbose_proxy_logger.error(
+ "Trying to send email alert to no recipient",
+ extra=webhook_event.dict(),
+ )
+
+ if webhook_event.event == "key_created":
+ email_html_content = KEY_CREATED_EMAIL_TEMPLATE.format(
+ email_logo_url=email_logo_url,
+ recipient_email=recipient_email,
+ key_budget=key_budget,
+ key_token=key_token,
+ base_url=base_url,
+ email_support_contact=email_support_contact,
+ )
+ elif webhook_event.event == "internal_user_created":
+ # GET TEAM NAME
+ team_id = webhook_event.team_id
+ team_name = "Default Team"
+ if team_id is not None and prisma_client is not None:
+ team_row = await prisma_client.db.litellm_teamtable.find_unique(
+ where={"team_id": team_id}
+ )
+ if team_row is not None:
+ team_name = team_row.team_alias or "-"
+ email_html_content = USER_INVITED_EMAIL_TEMPLATE.format(
+ email_logo_url=email_logo_url,
+ recipient_email=recipient_email,
+ team_name=team_name,
+ base_url=base_url,
+ email_support_contact=email_support_contact,
+ )
+ else:
+ verbose_proxy_logger.error(
+ "Trying to send email alert on unknown webhook event",
+ extra=webhook_event.model_dump(),
+ )
+
+ payload = webhook_event.model_dump_json()
+ email_event = {
+ "to": recipient_email,
+ "subject": f"LiteLLM: {event_name}",
+ "html": email_html_content,
+ }
+
+ response = await send_email(
+ receiver_email=email_event["to"],
+ subject=email_event["subject"],
+ html=email_event["html"],
+ )
+
+ return True
+
+ except Exception as e:
+ verbose_proxy_logger.error("Error sending email alert %s", str(e))
+ return False
+
+ async def send_email_alert_using_smtp(self, webhook_event: WebhookEvent) -> bool:
+ """
+ Sends structured Email alert to an SMTP server
+
+ Currently only implemented for budget alerts
+
+ Returns -> True if sent, False if not.
+ """
+ from litellm.proxy.utils import send_email
+
+ from litellm.proxy.proxy_server import premium_user, prisma_client
+
+ email_logo_url = os.getenv("SMTP_SENDER_LOGO", None)
+ email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
+ await self._check_if_using_premium_email_feature(
+ premium_user, email_logo_url, email_support_contact
+ )
+
+ if email_logo_url is None:
+ email_logo_url = LITELLM_LOGO_URL
+ if email_support_contact is None:
+ email_support_contact = LITELLM_SUPPORT_CONTACT
+
+ event_name = webhook_event.event_message
+ recipient_email = webhook_event.user_email
+ user_name = webhook_event.user_id
+ max_budget = webhook_event.max_budget
+ email_html_content = "Alert from LiteLLM Server"
+ if recipient_email is None:
+ verbose_proxy_logger.error(
+ "Trying to send email alert to no recipient", extra=webhook_event.dict()
+ )
+
+ if webhook_event.event == "budget_crossed":
+ email_html_content = f"""
+
+
+
+ Hi {user_name},
+
+ Your LLM API usage this month has reached your account's monthly budget of ${max_budget}
+
+ API requests will be rejected until either (a) you increase your monthly budget or (b) your monthly usage resets at the beginning of the next calendar month.
+
+ If you have any questions, please send an email to {email_support_contact}
+
+ Best,
+ The LiteLLM team
+ """
+
+ payload = webhook_event.model_dump_json()
+ email_event = {
+ "to": recipient_email,
+ "subject": f"LiteLLM: {event_name}",
+ "html": email_html_content,
+ }
+
+ response = await send_email(
+ receiver_email=email_event["to"],
+ subject=email_event["subject"],
+ html=email_event["html"],
+ )
+
+ return True
+
async def send_alert(
self,
message: str,
level: Literal["Low", "Medium", "High"],
- alert_type: Literal[
- "llm_exceptions",
- "llm_too_slow",
- "llm_requests_hanging",
- "budget_alerts",
- "db_exceptions",
- "daily_reports",
- "spend_reports",
- "new_model_added",
- "cooldown_deployment",
- ],
+ alert_type: AlertType,
+ user_info: Optional[WebhookEvent] = None,
**kwargs,
):
"""
@@ -755,6 +1381,24 @@ Model Info:
if self.alerting is None:
return
+ if (
+ "webhook" in self.alerting
+ and alert_type == "budget_alerts"
+ and user_info is not None
+ ):
+ await self.send_webhook_alert(webhook_event=user_info)
+
+ if (
+ "email" in self.alerting
+ and alert_type == "budget_alerts"
+ and user_info is not None
+ ):
+ # only send budget alerts over Email
+ await self.send_email_alert_using_smtp(webhook_event=user_info)
+
+ if "slack" not in self.alerting:
+ return
+
if alert_type not in self.alert_types:
return
@@ -801,46 +1445,78 @@ Model Info:
if response.status_code == 200:
pass
else:
- print("Error sending slack alert. Error=", response.text) # noqa
+ verbose_proxy_logger.debug(
+ "Error sending slack alert. Error=", response.text
+ )
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""Log deployment latency"""
- if "daily_reports" in self.alert_types:
- model_id = (
- kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
- )
- response_s: timedelta = end_time - start_time
-
- final_value = response_s
- total_tokens = 0
-
- if isinstance(response_obj, litellm.ModelResponse):
- completion_tokens = response_obj.usage.completion_tokens
- final_value = float(response_s.total_seconds() / completion_tokens)
-
- await self.async_update_daily_reports(
- DeploymentMetrics(
- id=model_id,
- failed_request=False,
- latency_per_output_token=final_value,
- updated_at=litellm.utils.get_utc_datetime(),
+ try:
+ if "daily_reports" in self.alert_types:
+ model_id = (
+ kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
)
+ response_s: timedelta = end_time - start_time
+
+ final_value = response_s
+ total_tokens = 0
+
+ if isinstance(response_obj, litellm.ModelResponse):
+ completion_tokens = response_obj.usage.completion_tokens
+ if completion_tokens is not None and completion_tokens > 0:
+ final_value = float(
+ response_s.total_seconds() / completion_tokens
+ )
+ if isinstance(final_value, timedelta):
+ final_value = final_value.total_seconds()
+
+ await self.async_update_daily_reports(
+ DeploymentMetrics(
+ id=model_id,
+ failed_request=False,
+ latency_per_output_token=final_value,
+ updated_at=litellm.utils.get_utc_datetime(),
+ )
+ )
+ except Exception as e:
+ verbose_proxy_logger.error(
+ "[Non-Blocking Error] Slack Alerting: Got error in logging LLM deployment latency: ",
+ e,
)
+ pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
"""Log failure + deployment latency"""
- if "daily_reports" in self.alert_types:
- model_id = (
- kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
- )
- await self.async_update_daily_reports(
- DeploymentMetrics(
- id=model_id,
- failed_request=True,
- latency_per_output_token=None,
- updated_at=litellm.utils.get_utc_datetime(),
- )
- )
+ _litellm_params = kwargs.get("litellm_params", {})
+ _model_info = _litellm_params.get("model_info", {}) or {}
+ model_id = _model_info.get("id", "")
+ try:
+ if "daily_reports" in self.alert_types:
+ try:
+ await self.async_update_daily_reports(
+ DeploymentMetrics(
+ id=model_id,
+ failed_request=True,
+ latency_per_output_token=None,
+ updated_at=litellm.utils.get_utc_datetime(),
+ )
+ )
+ except Exception as e:
+ verbose_logger.debug(f"Exception raises -{str(e)}")
+
+ if isinstance(kwargs.get("exception", ""), APIError):
+ if "outage_alerts" in self.alert_types:
+ await self.outage_alerts(
+ exception=kwargs["exception"],
+ deployment_id=model_id,
+ )
+
+ if "region_outage_alerts" in self.alert_types:
+ await self.region_outage_alerts(
+ exception=kwargs["exception"], deployment_id=model_id
+ )
+ except Exception as e:
+ pass
async def _run_scheduler_helper(self, llm_router) -> bool:
"""
@@ -852,40 +1528,26 @@ Model Info:
report_sent = await self.internal_usage_cache.async_get_cache(
key=SlackAlertingCacheKeys.report_sent_key.value
- ) # None | datetime
+ ) # None | float
- current_time = litellm.utils.get_utc_datetime()
+ current_time = time.time()
if report_sent is None:
- _current_time = current_time.isoformat()
await self.internal_usage_cache.async_set_cache(
key=SlackAlertingCacheKeys.report_sent_key.value,
- value=_current_time,
+ value=current_time,
)
- else:
+ elif isinstance(report_sent, float):
# Check if current time - interval >= time last sent
- delta_naive = timedelta(seconds=self.alerting_args.daily_report_frequency)
- if isinstance(report_sent, str):
- report_sent = dt.fromisoformat(report_sent)
+ interval_seconds = self.alerting_args.daily_report_frequency
- # Ensure report_sent is an aware datetime object
- if report_sent.tzinfo is None:
- report_sent = report_sent.replace(tzinfo=timezone.utc)
-
- # Calculate delta as an aware datetime object with the same timezone as report_sent
- delta = report_sent - delta_naive
-
- current_time_utc = current_time.astimezone(timezone.utc)
- delta_utc = delta.astimezone(timezone.utc)
-
- if current_time_utc >= delta_utc:
+ if current_time - report_sent >= interval_seconds:
# Sneak in the reporting logic here
await self.send_daily_reports(router=llm_router)
# Also, don't forget to update the report_sent time after sending the report!
- _current_time = current_time.isoformat()
await self.internal_usage_cache.async_set_cache(
key=SlackAlertingCacheKeys.report_sent_key.value,
- value=_current_time,
+ value=current_time,
)
report_sent_bool = True
diff --git a/litellm/integrations/traceloop.py b/litellm/integrations/traceloop.py
index bbdb9a1b0..e1c419c6f 100644
--- a/litellm/integrations/traceloop.py
+++ b/litellm/integrations/traceloop.py
@@ -1,114 +1,149 @@
+import traceback
+from litellm._logging import verbose_logger
+import litellm
+
+
class TraceloopLogger:
def __init__(self):
- from traceloop.sdk.tracing.tracing import TracerWrapper
- from traceloop.sdk import Traceloop
+ try:
+ from traceloop.sdk.tracing.tracing import TracerWrapper
+ from traceloop.sdk import Traceloop
+ from traceloop.sdk.instruments import Instruments
+ from opentelemetry.sdk.trace.export import ConsoleSpanExporter
+ except ModuleNotFoundError as e:
+ verbose_logger.error(
+ f"Traceloop not installed, try running 'pip install traceloop-sdk' to fix this error: {e}\n{traceback.format_exc()}"
+ )
- Traceloop.init(app_name="Litellm-Server", disable_batch=True)
+ Traceloop.init(
+ app_name="Litellm-Server",
+ disable_batch=True,
+ )
self.tracer_wrapper = TracerWrapper()
- def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
- from opentelemetry.trace import SpanKind
+ def log_event(
+ self,
+ kwargs,
+ response_obj,
+ start_time,
+ end_time,
+ user_id,
+ print_verbose,
+ level="DEFAULT",
+ status_message=None,
+ ):
+ from opentelemetry import trace
+ from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.semconv.ai import SpanAttributes
try:
+ print_verbose(
+ f"Traceloop Logging - Enters logging function for model {kwargs}"
+ )
+
tracer = self.tracer_wrapper.get_tracer()
- model = kwargs.get("model")
-
- # LiteLLM uses the standard OpenAI library, so it's already handled by Traceloop SDK
- if kwargs.get("litellm_params").get("custom_llm_provider") == "openai":
- return
-
optional_params = kwargs.get("optional_params", {})
- with tracer.start_as_current_span(
- "litellm.completion",
- kind=SpanKind.CLIENT,
- ) as span:
- if span.is_recording():
+ start_time = int(start_time.timestamp())
+ end_time = int(end_time.timestamp())
+ span = tracer.start_span(
+ "litellm.completion", kind=SpanKind.CLIENT, start_time=start_time
+ )
+
+ if span.is_recording():
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model")
+ )
+ if "stop" in optional_params:
span.set_attribute(
- SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model")
+ SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
+ optional_params.get("stop"),
)
- if "stop" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
- optional_params.get("stop"),
- )
- if "frequency_penalty" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_FREQUENCY_PENALTY,
- optional_params.get("frequency_penalty"),
- )
- if "presence_penalty" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_PRESENCE_PENALTY,
- optional_params.get("presence_penalty"),
- )
- if "top_p" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_TOP_P, optional_params.get("top_p")
- )
- if "tools" in optional_params or "functions" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_REQUEST_FUNCTIONS,
- optional_params.get(
- "tools", optional_params.get("functions")
- ),
- )
- if "user" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_USER, optional_params.get("user")
- )
- if "max_tokens" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_REQUEST_MAX_TOKENS,
- kwargs.get("max_tokens"),
- )
- if "temperature" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_TEMPERATURE, kwargs.get("temperature")
- )
-
- for idx, prompt in enumerate(kwargs.get("messages")):
- span.set_attribute(
- f"{SpanAttributes.LLM_PROMPTS}.{idx}.role",
- prompt.get("role"),
- )
- span.set_attribute(
- f"{SpanAttributes.LLM_PROMPTS}.{idx}.content",
- prompt.get("content"),
- )
-
+ if "frequency_penalty" in optional_params:
span.set_attribute(
- SpanAttributes.LLM_RESPONSE_MODEL, response_obj.get("model")
+ SpanAttributes.LLM_FREQUENCY_PENALTY,
+ optional_params.get("frequency_penalty"),
+ )
+ if "presence_penalty" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_PRESENCE_PENALTY,
+ optional_params.get("presence_penalty"),
+ )
+ if "top_p" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_TOP_P, optional_params.get("top_p")
+ )
+ if "tools" in optional_params or "functions" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_FUNCTIONS,
+ optional_params.get("tools", optional_params.get("functions")),
+ )
+ if "user" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_USER, optional_params.get("user")
+ )
+ if "max_tokens" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_MAX_TOKENS,
+ kwargs.get("max_tokens"),
+ )
+ if "temperature" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_TEMPERATURE,
+ kwargs.get("temperature"),
)
- usage = response_obj.get("usage")
- if usage:
- span.set_attribute(
- SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
- usage.get("total_tokens"),
- )
- span.set_attribute(
- SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
- usage.get("completion_tokens"),
- )
- span.set_attribute(
- SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
- usage.get("prompt_tokens"),
- )
- for idx, choice in enumerate(response_obj.get("choices")):
- span.set_attribute(
- f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason",
- choice.get("finish_reason"),
- )
- span.set_attribute(
- f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role",
- choice.get("message").get("role"),
- )
- span.set_attribute(
- f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content",
- choice.get("message").get("content"),
- )
+ for idx, prompt in enumerate(kwargs.get("messages")):
+ span.set_attribute(
+ f"{SpanAttributes.LLM_PROMPTS}.{idx}.role",
+ prompt.get("role"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_PROMPTS}.{idx}.content",
+ prompt.get("content"),
+ )
+
+ span.set_attribute(
+ SpanAttributes.LLM_RESPONSE_MODEL, response_obj.get("model")
+ )
+ usage = response_obj.get("usage")
+ if usage:
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
+ usage.get("total_tokens"),
+ )
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+ usage.get("completion_tokens"),
+ )
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+ usage.get("prompt_tokens"),
+ )
+
+ for idx, choice in enumerate(response_obj.get("choices")):
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason",
+ choice.get("finish_reason"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role",
+ choice.get("message").get("role"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content",
+ choice.get("message").get("content"),
+ )
+
+ if (
+ level == "ERROR"
+ and status_message is not None
+ and isinstance(status_message, str)
+ ):
+ span.record_exception(Exception(status_message))
+ span.set_status(Status(StatusCode.ERROR, status_message))
+
+ span.end(end_time)
except Exception as e:
print_verbose(f"Traceloop Layer Error - {e}")
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index f14dabc03..8e469a8f4 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -3,6 +3,7 @@ import json
from enum import Enum
import requests, copy # type: ignore
import time
+from functools import partial
from typing import Callable, Optional, List, Union
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
@@ -10,6 +11,7 @@ from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
+from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
class AnthropicConstants(Enum):
@@ -102,6 +104,17 @@ class AnthropicConfig:
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
+ if param == "tool_choice":
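+                # Map the OpenAI-style tool_choice value onto Anthropic's tool_choice format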
+ _tool_choice: Optional[AnthropicMessagesToolChoice] = None
+ if value == "auto":
+ _tool_choice = {"type": "auto"}
+ elif value == "required":
+ _tool_choice = {"type": "any"}
+ elif isinstance(value, dict):
+ _tool_choice = {"type": "tool", "name": value["function"]["name"]}
+
+ if _tool_choice is not None:
+ optional_params["tool_choice"] = _tool_choice
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":
@@ -148,6 +161,36 @@ def validate_environment(api_key, user_headers):
return headers
+async def make_call(
+ client: Optional[AsyncHTTPHandler],
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ if client is None:
+ client = AsyncHTTPHandler() # Create a new client if none provided
+
+ response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200:
+ raise AnthropicError(status_code=response.status_code, message=response.text)
+
+ completion_stream = response.aiter_lines()
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=completion_stream, # Pass the completion stream for logging
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
@@ -367,23 +410,34 @@ class AnthropicChatCompletion(BaseLLM):
logger_fn=None,
headers={},
):
- self.async_handler = AsyncHTTPHandler(
- timeout=httpx.Timeout(timeout=600.0, connect=5.0)
- )
data["stream"] = True
- response = await self.async_handler.post(
- api_base, headers=headers, data=json.dumps(data), stream=True
- )
+ # async_handler = AsyncHTTPHandler(
+ # timeout=httpx.Timeout(timeout=600.0, connect=20.0)
+ # )
- if response.status_code != 200:
- raise AnthropicError(
- status_code=response.status_code, message=response.text
- )
+ # response = await async_handler.post(
+ # api_base, headers=headers, json=data, stream=True
+ # )
- completion_stream = response.aiter_lines()
+ # if response.status_code != 200:
+ # raise AnthropicError(
+ # status_code=response.status_code, message=response.text
+ # )
+
+ # completion_stream = response.aiter_lines()
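+        # No request is made here; CustomStreamWrapper is expected to invoke make_call lazily to open the HTTP stream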
streamwrapper = CustomStreamWrapper(
- completion_stream=completion_stream,
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ client=None,
+ api_base=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
@@ -409,12 +463,10 @@ class AnthropicChatCompletion(BaseLLM):
logger_fn=None,
headers={},
) -> Union[ModelResponse, CustomStreamWrapper]:
- self.async_handler = AsyncHTTPHandler(
+ async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
- response = await self.async_handler.post(
- api_base, headers=headers, data=json.dumps(data)
- )
+ response = await async_handler.post(api_base, headers=headers, json=data)
if stream and _is_function_call:
return self.process_streaming_response(
model=model,
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 02fe4a08f..834fcbea9 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -1,4 +1,5 @@
-from typing import Optional, Union, Any, Literal
+from typing import Optional, Union, Any, Literal, Coroutine, Iterable
+from typing_extensions import overload
import types, requests
from .base import BaseLLM
from litellm.utils import (
@@ -9,6 +10,7 @@ from litellm.utils import (
convert_to_model_response_object,
TranscriptionResponse,
get_secret,
+ UnsupportedParamsError,
)
from typing import Callable, Optional, BinaryIO, List
from litellm import OpenAIConfig
@@ -18,6 +20,22 @@ from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTra
from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid
import os
+from ..types.llms.openai import (
+ AsyncCursorPage,
+ AssistantToolParam,
+ SyncCursorPage,
+ Assistant,
+ MessageData,
+ OpenAIMessage,
+ OpenAICreateThreadParamsMessage,
+ Thread,
+ Run,
+ AssistantEventHandler,
+ AsyncAssistantEventHandler,
+ AsyncAssistantStreamManager,
+ AssistantStreamManager,
+)
class AzureOpenAIError(Exception):
@@ -45,9 +63,9 @@ class AzureOpenAIError(Exception):
) # Call the base class constructor with the parameters it needs
-class AzureOpenAIConfig(OpenAIConfig):
+class AzureOpenAIConfig:
"""
- Reference: https://platform.openai.com/docs/api-reference/chat/create
+ Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
@@ -85,18 +103,111 @@ class AzureOpenAIConfig(OpenAIConfig):
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
- super().__init__(
- frequency_penalty,
- function_call,
- functions,
- logit_bias,
- max_tokens,
- n,
- presence_penalty,
- stop,
- temperature,
- top_p,
- )
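+        # Store non-None constructor args as class-level defaults (same pattern as the other litellm config classes)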
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(self):
+ return [
+ "temperature",
+ "n",
+ "stream",
+ "stop",
+ "max_tokens",
+ "tools",
+ "tool_choice",
+ "presence_penalty",
+ "frequency_penalty",
+ "logit_bias",
+ "user",
+ "function_call",
+ "functions",
+ "tools",
+ "tool_choice",
+ "top_p",
+ "logprobs",
+ "top_logprobs",
+ "response_format",
+ "seed",
+ "extra_headers",
+ ]
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ api_version: str, # Y-M-D-{optional}
+ drop_params,
+ ) -> dict:
+ supported_openai_params = self.get_supported_openai_params()
+
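+        # Azure api_version is date-based ("YYYY-MM-DD[-suffix]"); split it to gate params by release date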
+ api_version_times = api_version.split("-")
+ api_version_year = api_version_times[0]
+ api_version_month = api_version_times[1]
+ api_version_day = api_version_times[2]
+ for param, value in non_default_params.items():
+ if param == "tool_choice":
+ """
+ This parameter requires API version 2023-12-01-preview or later
+
+ tool_choice='required' is not supported as of 2024-05-01-preview
+ """
+ ## check if api version supports this param ##
+ if (
+ api_version_year < "2023"
+ or (api_version_year == "2023" and api_version_month < "12")
+ or (
+ api_version_year == "2023"
+ and api_version_month == "12"
+ and api_version_day < "01"
+ )
+ ):
+ if litellm.drop_params == True or (
+ drop_params is not None and drop_params == True
+ ):
+ pass
+ else:
+ raise UnsupportedParamsError(
+ status_code=400,
+ message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
+ )
+ elif value == "required" and (
+ api_version_year == "2024" and api_version_month <= "05"
+ ): ## check if tool_choice value is supported ##
+ if litellm.drop_params == True or (
+ drop_params is not None and drop_params == True
+ ):
+ pass
+ else:
+ raise UnsupportedParamsError(
+ status_code=400,
+ message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
+ )
+ else:
+ optional_params["tool_choice"] = value
+ elif param in supported_openai_params:
+ optional_params[param] = value
+ return optional_params
def get_mapped_special_auth_params(self) -> dict:
return {"token": "azure_ad_token"}
@@ -114,6 +225,68 @@ class AzureOpenAIConfig(OpenAIConfig):
return ["europe", "sweden", "switzerland", "france", "uk"]
+class AzureOpenAIAssistantsAPIConfig:
+ """
+ Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
+ """
+
+ def __init__(
+ self,
+ ) -> None:
+ pass
+
+ def get_supported_openai_create_message_params(self):
+ return [
+ "role",
+ "content",
+ "attachments",
+ "metadata",
+ ]
+
+ def map_openai_params_create_message_params(
+ self, non_default_params: dict, optional_params: dict
+ ):
+ for param, value in non_default_params.items():
+ if param == "role":
+ optional_params["role"] = value
+ if param == "metadata":
+ optional_params["metadata"] = value
+ elif param == "content": # only string accepted
+ if isinstance(value, str):
+ optional_params["content"] = value
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Azure only accepts content as a string.",
+ status_code=400,
+ )
+ elif (
+ param == "attachments"
+ ): # this is a v2 param. Azure currently supports the old 'file_id's param
+ file_ids: List[str] = []
+ if isinstance(value, list):
+ for item in value:
+ if "file_id" in item:
+ file_ids.append(item["file_id"])
+ else:
+ if litellm.drop_params == True:
+ pass
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Azure doesn't support {}. To drop it from the call, set `litellm.drop_params = True.".format(
+ value
+ ),
+ status_code=400,
+ )
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Invalid param. attachments should always be a list. Got={}, Expected=List. Raw value={}".format(
+ type(value), value
+ ),
+ status_code=400,
+ )
+ return optional_params
+
+
def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = {
# "api_version": api_version,
@@ -172,9 +345,7 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
possible_azure_ad_token = req_token.json().get("access_token", None)
if possible_azure_ad_token is None:
- raise AzureOpenAIError(
- status_code=422, message="Azure AD Token not returned"
- )
+ raise AzureOpenAIError(status_code=422, message="Azure AD Token not returned")
return possible_azure_ad_token
@@ -245,7 +416,9 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
- azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+ azure_ad_token = get_azure_ad_token_from_oidc(
+ azure_ad_token
+ )
azure_client_params["azure_ad_token"] = azure_ad_token
@@ -1192,3 +1365,828 @@ class AzureChatCompletion(BaseLLM):
response["x-ms-region"] = completion.headers["x-ms-region"]
return response
+
+
+class AzureAssistantsAPI(BaseLLM):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def get_azure_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI] = None,
+ ) -> AzureOpenAI:
+ received_args = locals()
+ if client is None:
+ data = {}
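+            # Build AzureOpenAI kwargs from the received args: rename api_base -> azure_endpoint and skip None values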
+ for k, v in received_args.items():
+ if k == "self" or k == "client":
+ pass
+ elif k == "api_base" and v is not None:
+ data["azure_endpoint"] = v
+ elif v is not None:
+ data[k] = v
+ azure_openai_client = AzureOpenAI(**data) # type: ignore
+ else:
+ azure_openai_client = client
+
+ return azure_openai_client
+
+ def async_get_azure_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI] = None,
+ ) -> AsyncAzureOpenAI:
+ received_args = locals()
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client":
+ pass
+ elif k == "api_base" and v is not None:
+ data["azure_endpoint"] = v
+ elif v is not None:
+ data[k] = v
+
+ azure_openai_client = AsyncAzureOpenAI(**data)
+ # azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore
+ else:
+ azure_openai_client = client
+
+ return azure_openai_client
+
+ ### ASSISTANTS ###
+
+ async def async_get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ ) -> AsyncCursorPage[Assistant]:
+ azure_openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = await azure_openai_client.beta.assistants.list()
+
+ return response
+
+ # fmt: off
+
+ @overload
+ def get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ aget_assistants: Literal[True],
+ ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
+ ...
+
+ @overload
+ def get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI],
+ aget_assistants: Optional[Literal[False]],
+ ) -> SyncCursorPage[Assistant]:
+ ...
+
+ # fmt: on
+
+ def get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client=None,
+ aget_assistants=None,
+ ):
+ if aget_assistants is not None and aget_assistants == True:
+ return self.async_get_assistants(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ azure_openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ api_version=api_version,
+ )
+
+ response = azure_openai_client.beta.assistants.list()
+
+ return response
+
+ ### MESSAGES ###
+
+ async def a_add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI] = None,
+ ) -> OpenAIMessage:
+ openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
+ thread_id, **message_data # type: ignore
+ )
+
+ response_obj: Optional[OpenAIMessage] = None
+ if getattr(thread_message, "status", None) is None:
+ thread_message.status = "completed"
+ response_obj = OpenAIMessage(**thread_message.dict())
+ else:
+ response_obj = OpenAIMessage(**thread_message.dict())
+ return response_obj
+
+ # fmt: off
+
+ @overload
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ a_add_message: Literal[True],
+ ) -> Coroutine[None, None, OpenAIMessage]:
+ ...
+
+ @overload
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI],
+ a_add_message: Optional[Literal[False]],
+ ) -> OpenAIMessage:
+ ...
+
+ # fmt: on
+
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client=None,
+ a_add_message: Optional[bool] = None,
+ ):
+ if a_add_message is not None and a_add_message == True:
+ return self.a_add_message(
+ thread_id=thread_id,
+ message_data=message_data,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
+ thread_id, **message_data # type: ignore
+ )
+
+ response_obj: Optional[OpenAIMessage] = None
+ if getattr(thread_message, "status", None) is None:
+ thread_message.status = "completed"
+ response_obj = OpenAIMessage(**thread_message.dict())
+ else:
+ response_obj = OpenAIMessage(**thread_message.dict())
+ return response_obj
+
+ async def async_get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI] = None,
+ ) -> AsyncCursorPage[OpenAIMessage]:
+ openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+ return response
+
+ # fmt: off
+
+ @overload
+ def get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ aget_messages: Literal[True],
+ ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
+ ...
+
+ @overload
+ def get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI],
+ aget_messages: Optional[Literal[False]],
+ ) -> SyncCursorPage[OpenAIMessage]:
+ ...
+
+ # fmt: on
+
+ def get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client=None,
+ aget_messages=None,
+ ):
+ if aget_messages is not None and aget_messages == True:
+ return self.async_get_messages(
+ thread_id=thread_id,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+ return response
+
+ ### THREADS ###
+
+ async def async_create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ ) -> Thread:
+ openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ data = {}
+ if messages is not None:
+ data["messages"] = messages # type: ignore
+ if metadata is not None:
+ data["metadata"] = metadata # type: ignore
+
+ message_thread = await openai_client.beta.threads.create(**data) # type: ignore
+
+ return Thread(**message_thread.dict())
+
+ # fmt: off
+
+ @overload
+ def create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ client: Optional[AsyncAzureOpenAI],
+ acreate_thread: Literal[True],
+ ) -> Coroutine[None, None, Thread]:
+ ...
+
+ @overload
+ def create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ client: Optional[AzureOpenAI],
+ acreate_thread: Optional[Literal[False]],
+ ) -> Thread:
+ ...
+
+ # fmt: on
+
+ def create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ client=None,
+ acreate_thread=None,
+ ):
+ """
+ Here's an example:
+ ```
+        from litellm.llms.azure import AzureAssistantsAPI, MessageData
+
+ # create thread
+ message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
+        azure_assistants_api.create_thread(messages=[message])
+ ```
+ """
+ if acreate_thread is not None and acreate_thread == True:
+ return self.async_create_thread(
+ metadata=metadata,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ messages=messages,
+ )
+ azure_openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ data = {}
+ if messages is not None:
+ data["messages"] = messages # type: ignore
+ if metadata is not None:
+ data["metadata"] = metadata # type: ignore
+
+ message_thread = azure_openai_client.beta.threads.create(**data) # type: ignore
+
+ return Thread(**message_thread.dict())
+
+ async def async_get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ ) -> Thread:
+ openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+ return Thread(**response.dict())
+
+ # fmt: off
+
+ @overload
+ def get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ aget_thread: Literal[True],
+ ) -> Coroutine[None, None, Thread]:
+ ...
+
+ @overload
+ def get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI],
+ aget_thread: Optional[Literal[False]],
+ ) -> Thread:
+ ...
+
+ # fmt: on
+
+ def get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client=None,
+ aget_thread=None,
+ ):
+ if aget_thread is not None and aget_thread == True:
+ return self.async_get_thread(
+ thread_id=thread_id,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+ return Thread(**response.dict())
+
+ # def delete_thread(self):
+ # pass
+
+ ### RUNS ###
+
+ async def arun_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ ) -> Run:
+ openai_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ )
+
+ return response
+
+ def async_run_thread_stream(
+ self,
+ client: AsyncAzureOpenAI,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ tools: Optional[Iterable[AssistantToolParam]],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
+ data = {
+ "thread_id": thread_id,
+ "assistant_id": assistant_id,
+ "additional_instructions": additional_instructions,
+ "instructions": instructions,
+ "metadata": metadata,
+ "model": model,
+ "tools": tools,
+ }
+ if event_handler is not None:
+ data["event_handler"] = event_handler
+ return client.beta.threads.runs.stream(**data) # type: ignore
+
+ def run_thread_stream(
+ self,
+ client: AzureOpenAI,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ tools: Optional[Iterable[AssistantToolParam]],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> AssistantStreamManager[AssistantEventHandler]:
+ data = {
+ "thread_id": thread_id,
+ "assistant_id": assistant_id,
+ "additional_instructions": additional_instructions,
+ "instructions": instructions,
+ "metadata": metadata,
+ "model": model,
+ "tools": tools,
+ }
+ if event_handler is not None:
+ data["event_handler"] = event_handler
+ return client.beta.threads.runs.stream(**data) # type: ignore
+
+ # fmt: off
+
+ @overload
+ def run_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AsyncAzureOpenAI],
+ arun_thread: Literal[True],
+ ) -> Coroutine[None, None, Run]:
+ ...
+
+ @overload
+ def run_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[AzureOpenAI],
+ arun_thread: Optional[Literal[False]],
+ ) -> Run:
+ ...
+
+ # fmt: on
+
+ def run_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client=None,
+ arun_thread=None,
+ event_handler: Optional[AssistantEventHandler] = None,
+ ):
+ if arun_thread is not None and arun_thread == True:
+ if stream is not None and stream == True:
+ azure_client = self.async_get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ return self.async_run_thread_stream(
+ client=azure_client,
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ event_handler=event_handler,
+ )
+ return self.arun_thread(
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ stream=stream,
+ tools=tools,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+ openai_client = self.get_azure_client(
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ if stream is not None and stream == True:
+ return self.run_thread_stream(
+ client=openai_client,
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ event_handler=event_handler,
+ )
+
+ response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ )
+
+ return response
diff --git a/litellm/llms/base.py b/litellm/llms/base.py
index d940d9471..8c2f5101e 100644
--- a/litellm/llms/base.py
+++ b/litellm/llms/base.py
@@ -21,7 +21,7 @@ class BaseLLM:
messages: list,
print_verbose,
encoding,
- ) -> litellm.utils.ModelResponse:
+ ) -> Union[litellm.utils.ModelResponse, litellm.utils.CustomStreamWrapper]:
"""
Helper function to process the response across sync + async completion calls
"""
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 1ff3767bd..dbd7e7c69 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -1,7 +1,7 @@
# What is this?
## Initial implementation of calling bedrock via httpx client (allows for async calls).
-## V0 - just covers cohere command-r support
-
+## V1 - covers cohere + anthropic claude-3 support
+from functools import partial
import os, types
import json
from enum import Enum
@@ -29,13 +29,22 @@ from litellm.utils import (
get_secret,
Logging,
)
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
+import litellm, uuid
+from .prompt_templates.factory import (
+ prompt_factory,
+ custom_prompt,
+ cohere_message_pt,
+ construct_tool_use_system_prompt,
+ extract_between_tags,
+ parse_xml_params,
+ contains_tag,
+)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
import httpx # type: ignore
-from .bedrock import BedrockError, convert_messages_to_prompt
+from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
from litellm.types.llms.bedrock import *
+import urllib.parse
class AmazonCohereChatConfig:
@@ -136,6 +145,37 @@ class AmazonCohereChatConfig:
return optional_params
+async def make_call(
+ client: Optional[AsyncHTTPHandler],
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ if client is None:
+ client = AsyncHTTPHandler() # Create a new client if none provided
+
+ response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200:
+ raise BedrockError(status_code=response.status_code, message=response.text)
+
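+    # Decode the AWS event-stream bytes into provider-specific JSON chunks for the stream wrapper to consume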
+ decoder = AWSEventStreamDecoder(model=model)
+ completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=completion_stream, # Pass the completion stream for logging
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
class BedrockLLM(BaseLLM):
"""
Example call
@@ -208,6 +248,7 @@ class BedrockLLM(BaseLLM):
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
+ aws_web_identity_token: Optional[str] = None,
):
"""
Return a boto3.Credentials object
@@ -222,6 +263,7 @@ class BedrockLLM(BaseLLM):
aws_session_name,
aws_profile_name,
aws_role_name,
+ aws_web_identity_token,
]
# Iterate over parameters and update if needed
@@ -238,10 +280,43 @@ class BedrockLLM(BaseLLM):
aws_session_name,
aws_profile_name,
aws_role_name,
+ aws_web_identity_token,
) = params_to_check
### CHECK STS ###
- if aws_role_name is not None and aws_session_name is not None:
+ if (
+ aws_web_identity_token is not None
+ and aws_role_name is not None
+ and aws_session_name is not None
+ ):
+ oidc_token = get_secret(aws_web_identity_token)
+
+ if oidc_token is None:
+ raise BedrockError(
+ message="OIDC token could not be retrieved from secret manager.",
+ status_code=401,
+ )
+
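+            # Exchange the OIDC token for temporary AWS credentials via STS AssumeRoleWithWebIdentity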
+ sts_client = boto3.client("sts")
+
+ # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
+ sts_response = sts_client.assume_role_with_web_identity(
+ RoleArn=aws_role_name,
+ RoleSessionName=aws_session_name,
+ WebIdentityToken=oidc_token,
+ DurationSeconds=3600,
+ )
+
+ session = boto3.Session(
+ aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
+ aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
+ aws_session_token=sts_response["Credentials"]["SessionToken"],
+ region_name=aws_region_name,
+ )
+
+ return session.get_credentials()
+ elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
@@ -252,7 +327,16 @@ class BedrockLLM(BaseLLM):
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
- return sts_response["Credentials"]
+ # Extract the credentials from the response and convert to Session Credentials
+ sts_credentials = sts_response["Credentials"]
+ from botocore.credentials import Credentials
+
+ credentials = Credentials(
+ access_key=sts_credentials["AccessKeyId"],
+ secret_key=sts_credentials["SecretAccessKey"],
+ token=sts_credentials["SessionToken"],
+ )
+ return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
@@ -280,7 +364,8 @@ class BedrockLLM(BaseLLM):
messages: List,
print_verbose,
encoding,
- ) -> ModelResponse:
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
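+        # Bedrock model ids are prefixed with the provider name, e.g. "anthropic.claude-3-..."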
+ provider = model.split(".")[0]
## LOGGING
logging_obj.post_call(
input=messages,
@@ -297,26 +382,210 @@ class BedrockLLM(BaseLLM):
raise BedrockError(message=response.text, status_code=422)
try:
- model_response.choices[0].message.content = completion_response["text"] # type: ignore
+ if provider == "cohere":
+ if "text" in completion_response:
+ outputText = completion_response["text"] # type: ignore
+ elif "generations" in completion_response:
+ outputText = completion_response["generations"][0]["text"]
+ model_response["finish_reason"] = map_finish_reason(
+ completion_response["generations"][0]["finish_reason"]
+ )
+ elif provider == "anthropic":
+ if model.startswith("anthropic.claude-3"):
+ json_schemas: dict = {}
+ _is_function_call = False
+ ## Handle Tool Calling
+ if "tools" in optional_params:
+ _is_function_call = True
+ for tool in optional_params["tools"]:
+ json_schemas[tool["function"]["name"]] = tool[
+ "function"
+ ].get("parameters", None)
+ outputText = completion_response.get("content")[0].get("text", None)
+ if outputText is not None and contains_tag(
+ "invoke", outputText
+ ): # OUTPUT PARSE FUNCTION CALL
+ function_name = extract_between_tags("tool_name", outputText)[0]
+ function_arguments_str = extract_between_tags(
+ "invoke", outputText
+ )[0].strip()
+                        function_arguments_str = (
+                            f"<invoke>{function_arguments_str}</invoke>"
+                        )
+ function_arguments = parse_xml_params(
+ function_arguments_str,
+ json_schema=json_schemas.get(
+ function_name, None
+ ), # check if we have a json schema for this function name)
+ )
+ _message = litellm.Message(
+ tool_calls=[
+ {
+ "id": f"call_{uuid.uuid4()}",
+ "type": "function",
+ "function": {
+ "name": function_name,
+ "arguments": json.dumps(function_arguments),
+ },
+ }
+ ],
+ content=None,
+ )
+ model_response.choices[0].message = _message # type: ignore
+ model_response._hidden_params["original_response"] = (
+ outputText # allow user to access raw anthropic tool calling response
+ )
+ if (
+ _is_function_call == True
+ and stream is not None
+ and stream == True
+ ):
+ print_verbose(
+ f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
+ )
+ # return an iterator
+ streaming_model_response = ModelResponse(stream=True)
+ streaming_model_response.choices[0].finish_reason = getattr(
+ model_response.choices[0], "finish_reason", "stop"
+ )
+ # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+ streaming_choice = litellm.utils.StreamingChoices()
+ streaming_choice.index = model_response.choices[0].index
+ _tool_calls = []
+ print_verbose(
+ f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+ )
+ print_verbose(
+ f"type of streaming_choice: {type(streaming_choice)}"
+ )
+ if isinstance(model_response.choices[0], litellm.Choices):
+ if getattr(
+ model_response.choices[0].message, "tool_calls", None
+ ) is not None and isinstance(
+ model_response.choices[0].message.tool_calls, list
+ ):
+ for tool_call in model_response.choices[
+ 0
+ ].message.tool_calls:
+ _tool_call = {**tool_call.dict(), "index": 0}
+ _tool_calls.append(_tool_call)
+ delta_obj = litellm.utils.Delta(
+ content=getattr(
+ model_response.choices[0].message, "content", None
+ ),
+ role=model_response.choices[0].message.role,
+ tool_calls=_tool_calls,
+ )
+ streaming_choice.delta = delta_obj
+ streaming_model_response.choices = [streaming_choice]
+ completion_stream = ModelResponseIterator(
+ model_response=streaming_model_response
+ )
+ print_verbose(
+ f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+ )
+ return litellm.CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider="cached_response",
+ logging_obj=logging_obj,
+ )
+
+ model_response["finish_reason"] = map_finish_reason(
+ completion_response.get("stop_reason", "")
+ )
+ _usage = litellm.Usage(
+ prompt_tokens=completion_response["usage"]["input_tokens"],
+ completion_tokens=completion_response["usage"]["output_tokens"],
+ total_tokens=completion_response["usage"]["input_tokens"]
+ + completion_response["usage"]["output_tokens"],
+ )
+ setattr(model_response, "usage", _usage)
+ else:
+ outputText = completion_response["completion"]
+
+ model_response["finish_reason"] = completion_response["stop_reason"]
+ elif provider == "ai21":
+ outputText = (
+ completion_response.get("completions")[0].get("data").get("text")
+ )
+ elif provider == "meta":
+ outputText = completion_response["generation"]
+ elif provider == "mistral":
+ outputText = completion_response["outputs"][0]["text"]
+ model_response["finish_reason"] = completion_response["outputs"][0][
+ "stop_reason"
+ ]
+ else: # amazon titan
+ outputText = completion_response.get("results")[0].get("outputText")
except Exception as e:
- raise BedrockError(message=response.text, status_code=422)
+ raise BedrockError(
+ message="Error processing={}, Received error={}".format(
+ response.text, str(e)
+ ),
+ status_code=422,
+ )
+
+ try:
+ if (
+ len(outputText) > 0
+ and hasattr(model_response.choices[0], "message")
+ and getattr(model_response.choices[0].message, "tool_calls", None)
+ is None
+ ):
+ model_response["choices"][0]["message"]["content"] = outputText
+ elif (
+ hasattr(model_response.choices[0], "message")
+ and getattr(model_response.choices[0].message, "tool_calls", None)
+ is not None
+ ):
+ pass
+ else:
+ raise Exception()
+ except:
+ raise BedrockError(
+ message=json.dumps(outputText), status_code=response.status_code
+ )
+
+ if stream and provider == "ai21":
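+            # AI21 is not streamed via invoke-with-response-stream here; wrap the full response as a single-chunk "cached_response" stream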
+ streaming_model_response = ModelResponse(stream=True)
+ streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
+ 0
+ ].finish_reason
+ # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+ streaming_choice = litellm.utils.StreamingChoices()
+ streaming_choice.index = model_response.choices[0].index
+ delta_obj = litellm.utils.Delta(
+ content=getattr(model_response.choices[0].message, "content", None),
+ role=model_response.choices[0].message.role,
+ )
+ streaming_choice.delta = delta_obj
+ streaming_model_response.choices = [streaming_choice]
+ mri = ModelResponseIterator(model_response=streaming_model_response)
+ return CustomStreamWrapper(
+ completion_stream=mri,
+ model=model,
+ custom_llm_provider="cached_response",
+ logging_obj=logging_obj,
+ )
## CALCULATING USAGE - bedrock returns usage in the headers
- prompt_tokens = int(
- response.headers.get(
- "x-amzn-bedrock-input-token-count",
- len(encoding.encode("".join(m.get("content", "") for m in messages))),
- )
+ bedrock_input_tokens = response.headers.get(
+ "x-amzn-bedrock-input-token-count", None
)
+ bedrock_output_tokens = response.headers.get(
+ "x-amzn-bedrock-output-token-count", None
+ )
+
+ prompt_tokens = int(
+ bedrock_input_tokens or litellm.token_counter(messages=messages)
+ )
+
completion_tokens = int(
- response.headers.get(
- "x-amzn-bedrock-output-token-count",
- len(
- encoding.encode(
- model_response.choices[0].message.content, # type: ignore
- disallowed_special=(),
- )
- ),
+ bedrock_output_tokens
+ or litellm.token_counter(
+ text=model_response.choices[0].message.content, # type: ignore
+ count_response_tokens=True,
)
)
@@ -331,6 +600,16 @@ class BedrockLLM(BaseLLM):
return model_response
+ def encode_model_id(self, model_id: str) -> str:
+ """
+ Double encode the model ID to ensure it matches the expected double-encoded format.
+ Args:
+ model_id (str): The model ID to encode.
+ Returns:
+ str: The double-encoded model ID.
+ """
+ return urllib.parse.quote(model_id, safe="")
+
def completion(
self,
model: str,
@@ -359,6 +638,13 @@ class BedrockLLM(BaseLLM):
## SETUP ##
stream = optional_params.pop("stream", None)
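+        # If an explicit `model_id` is passed, URL-encode it and use it for the invoke endpoint instead of `model`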
+ modelId = optional_params.pop("model_id", None)
+ if modelId is not None:
+ modelId = self.encode_model_id(model_id=modelId)
+ else:
+ modelId = model
+
+ provider = model.split(".")[0]
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@@ -371,6 +657,7 @@ class BedrockLLM(BaseLLM):
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
+ aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
### SET REGION NAME ###
if aws_region_name is None:
@@ -398,6 +685,7 @@ class BedrockLLM(BaseLLM):
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
+ aws_web_identity_token=aws_web_identity_token,
)
### SET RUNTIME ENDPOINT ###
@@ -414,19 +702,18 @@ class BedrockLLM(BaseLLM):
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
- if stream is not None and stream == True:
- endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
+ if (stream is not None and stream == True) and provider != "ai21":
+ endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
else:
- endpoint_url = f"{endpoint_url}/model/{model}/invoke"
+ endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
- provider = model.split(".")[0]
prompt, chat_history = self.convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
inference_params = copy.deepcopy(optional_params)
-
+ json_schemas: dict = {}
if provider == "cohere":
if model.startswith("cohere.command-r"):
## LOAD CONFIG
@@ -453,8 +740,114 @@ class BedrockLLM(BaseLLM):
True # cohere requires stream = True in inference params
)
data = json.dumps({"prompt": prompt, **inference_params})
+ elif provider == "anthropic":
+ if model.startswith("anthropic.claude-3"):
+ # Separate system prompt from rest of message
+ system_prompt_idx: list[int] = []
+ system_messages: list[str] = []
+ for idx, message in enumerate(messages):
+ if message["role"] == "system":
+ system_messages.append(message["content"])
+ system_prompt_idx.append(idx)
+ if len(system_prompt_idx) > 0:
+ inference_params["system"] = "\n".join(system_messages)
+ messages = [
+ i for j, i in enumerate(messages) if j not in system_prompt_idx
+ ]
+ # Format rest of message according to anthropic guidelines
+ messages = prompt_factory(
+ model=model, messages=messages, custom_llm_provider="anthropic_xml"
+ ) # type: ignore
+ ## LOAD CONFIG
+ config = litellm.AmazonAnthropicClaude3Config.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+ ## Handle Tool Calling
+ if "tools" in inference_params:
+ _is_function_call = True
+ for tool in inference_params["tools"]:
+ json_schemas[tool["function"]["name"]] = tool["function"].get(
+ "parameters", None
+ )
+ tool_calling_system_prompt = construct_tool_use_system_prompt(
+ tools=inference_params["tools"]
+ )
+ inference_params["system"] = (
+ inference_params.get("system", "\n")
+ + tool_calling_system_prompt
+ ) # add the anthropic tool calling prompt to the system prompt
+ inference_params.pop("tools")
+ data = json.dumps({"messages": messages, **inference_params})
+ else:
+ ## LOAD CONFIG
+ config = litellm.AmazonAnthropicConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+ data = json.dumps({"prompt": prompt, **inference_params})
+ elif provider == "ai21":
+ ## LOAD CONFIG
+ config = litellm.AmazonAI21Config.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+
+ data = json.dumps({"prompt": prompt, **inference_params})
+ elif provider == "mistral":
+ ## LOAD CONFIG
+ config = litellm.AmazonMistralConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+
+ data = json.dumps({"prompt": prompt, **inference_params})
+ elif provider == "amazon": # amazon titan
+ ## LOAD CONFIG
+ config = litellm.AmazonTitanConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+
+ data = json.dumps(
+ {
+ "inputText": prompt,
+ "textGenerationConfig": inference_params,
+ }
+ )
+ elif provider == "meta":
+ ## LOAD CONFIG
+ config = litellm.AmazonLlamaConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+ data = json.dumps({"prompt": prompt, **inference_params})
else:
- raise Exception("UNSUPPORTED PROVIDER")
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": inference_params,
+ },
+ )
+ raise Exception(
+ "Bedrock HTTPX: Unsupported provider={}, model={}".format(
+ provider, model
+ )
+ )
## COMPLETION CALL
@@ -482,7 +875,7 @@ class BedrockLLM(BaseLLM):
if acompletion:
if isinstance(client, HTTPHandler):
client = None
- if stream:
+ if stream == True and provider != "ai21":
return self.async_streaming(
model=model,
messages=messages,
@@ -511,7 +904,7 @@ class BedrockLLM(BaseLLM):
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
- stream=False,
+ stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
@@ -528,7 +921,7 @@ class BedrockLLM(BaseLLM):
self.client = HTTPHandler(**_params) # type: ignore
else:
self.client = client
- if stream is not None and stream == True:
+ if (stream is not None and stream == True) and provider != "ai21":
response = self.client.post(
url=prepped.url,
headers=prepped.headers, # type: ignore
@@ -541,7 +934,7 @@ class BedrockLLM(BaseLLM):
status_code=response.status_code, message=response.text
)
- decoder = AWSEventStreamDecoder()
+ decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
@@ -550,15 +943,24 @@ class BedrockLLM(BaseLLM):
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=streaming_response,
+ additional_args={"complete_input_dict": data},
+ )
return streaming_response
- response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
-
try:
+ response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
+ except httpx.TimeoutException as e:
+ raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
@@ -591,7 +993,7 @@ class BedrockLLM(BaseLLM):
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
- ) -> ModelResponse:
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
if client is None:
_params = {}
if timeout is not None:
@@ -602,12 +1004,20 @@ class BedrockLLM(BaseLLM):
else:
self.client = client # type: ignore
- response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
+ try:
+ response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise BedrockError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException as e:
+ raise BedrockError(status_code=408, message="Timeout error occurred.")
+
return self.process_response(
model=model,
response=response,
model_response=model_response,
- stream=stream,
+ stream=stream if isinstance(stream, bool) else False,
logging_obj=logging_obj,
api_key="",
data=data,
@@ -635,26 +1045,20 @@ class BedrockLLM(BaseLLM):
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
- if client is None:
- _params = {}
- if timeout is not None:
- if isinstance(timeout, float) or isinstance(timeout, int):
- timeout = httpx.Timeout(timeout)
- _params["timeout"] = timeout
- self.client = AsyncHTTPHandler(**_params) # type: ignore
- else:
- self.client = client # type: ignore
+ # The call is not made here; instead, we prepare the necessary objects for the stream.
- response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore
-
- if response.status_code != 200:
- raise BedrockError(status_code=response.status_code, message=response.text)
-
- decoder = AWSEventStreamDecoder()
-
- completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
- completion_stream=completion_stream,
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ client=client,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
@@ -676,11 +1080,70 @@ def get_response_stream_shape():
class AWSEventStreamDecoder:
- def __init__(self) -> None:
+ def __init__(self, model: str) -> None:
from botocore.parsers import EventStreamJSONParser
+ self.model = model
self.parser = EventStreamJSONParser()
+ def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+ text = ""
+ is_finished = False
+ finish_reason = ""
+ if "outputText" in chunk_data:
+ text = chunk_data["outputText"]
+ # ai21 mapping
+ if "ai21" in self.model: # fake ai21 streaming
+ text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
+ is_finished = True
+ finish_reason = "stop"
+ ######## bedrock.anthropic mappings ###############
+ elif "completion" in chunk_data: # not claude-3
+ text = chunk_data["completion"] # bedrock.anthropic
+ stop_reason = chunk_data.get("stop_reason", None)
+ if stop_reason is not None:
+ is_finished = True
+ finish_reason = stop_reason
+ elif "delta" in chunk_data:
+ if chunk_data["delta"].get("text", None) is not None:
+ text = chunk_data["delta"]["text"]
+ stop_reason = chunk_data["delta"].get("stop_reason", None)
+ if stop_reason is not None:
+ is_finished = True
+ finish_reason = stop_reason
+ ######## bedrock.mistral mappings ###############
+ elif "outputs" in chunk_data:
+ if (
+ len(chunk_data["outputs"]) == 1
+ and chunk_data["outputs"][0].get("text", None) is not None
+ ):
+ text = chunk_data["outputs"][0]["text"]
+ stop_reason = chunk_data.get("stop_reason", None)
+ if stop_reason is not None:
+ is_finished = True
+ finish_reason = stop_reason
+ ######## bedrock.meta mappings ###############
+ elif "generation" in chunk_data:
+ text = chunk_data["generation"] # bedrock.meta
+ ######## bedrock.cohere mappings ###############
+ elif "text" in chunk_data:
+ text = chunk_data["text"] # bedrock.cohere
+ # cohere mapping for finish reason
+ elif "finish_reason" in chunk_data:
+ finish_reason = chunk_data["finish_reason"]
+ is_finished = True
+ elif chunk_data.get("completionReason", None):
+ is_finished = True
+ finish_reason = chunk_data["completionReason"]
+ return GenericStreamingChunk(
+ **{
+ "text": text,
+ "is_finished": is_finished,
+ "finish_reason": finish_reason,
+ }
+ )
+
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
@@ -693,12 +1156,7 @@ class AWSEventStreamDecoder:
if message:
# sse_event = ServerSentEvent(data=message, event="completion")
_data = json.loads(message)
- streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
- text=_data.get("text", ""),
- is_finished=_data.get("is_finished", False),
- finish_reason=_data.get("finish_reason", ""),
- )
- yield streaming_chunk
+ yield self._chunk_parser(chunk_data=_data)
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
@@ -713,12 +1171,7 @@ class AWSEventStreamDecoder:
message = self._parse_message_from_event(event)
if message:
_data = json.loads(message)
- streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
- text=_data.get("text", ""),
- is_finished=_data.get("is_finished", False),
- finish_reason=_data.get("finish_reason", ""),
- )
- yield streaming_chunk
+ yield self._chunk_parser(chunk_data=_data)
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()
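
The `_chunk_parser` added above normalizes each provider-specific Bedrock stream payload into a `GenericStreamingChunk`. Below is a minimal, self-contained sketch of that mapping for a few providers, using hypothetical chunk payloads purely for illustration (it does not import the decoder itself):

# Illustrative only: hypothetical Bedrock stream chunks and how they are normalized.
anthropic_chunk = {"delta": {"text": "Hello", "stop_reason": None}}   # claude-3 style
meta_chunk = {"generation": " world"}                                  # bedrock.meta
cohere_finish = {"finish_reason": "COMPLETE"}                          # bedrock.cohere finish marker

def describe(chunk: dict) -> dict:
    """Rough re-statement of the mapping logic for the shapes assumed above."""
    if "delta" in chunk:  # bedrock.anthropic (claude-3)
        return {"text": chunk["delta"].get("text", ""),
                "is_finished": chunk["delta"].get("stop_reason") is not None}
    if "generation" in chunk:  # bedrock.meta
        return {"text": chunk["generation"], "is_finished": False}
    if "finish_reason" in chunk:  # bedrock.cohere finish marker
        return {"text": "", "is_finished": True}
    return {"text": "", "is_finished": False}

for c in (anthropic_chunk, meta_chunk, cohere_finish):
    print(describe(c))
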
diff --git a/litellm/llms/clarifai.py b/litellm/llms/clarifai.py
index e07a8d9e8..4610911e1 100644
--- a/litellm/llms/clarifai.py
+++ b/litellm/llms/clarifai.py
@@ -14,28 +14,25 @@ class ClarifaiError(Exception):
def __init__(self, status_code, message, url):
self.status_code = status_code
self.message = message
- self.request = httpx.Request(
- method="POST", url=url
- )
+ self.request = httpx.Request(method="POST", url=url)
self.response = httpx.Response(status_code=status_code, request=self.request)
- super().__init__(
- self.message
- )
+ super().__init__(self.message)
+
class ClarifaiConfig:
"""
Reference: https://clarifai.com/meta/Llama-2/models/llama2-70b-chat
- TODO fill in the details
"""
+
max_tokens: Optional[int] = None
temperature: Optional[int] = None
top_k: Optional[int] = None
def __init__(
- self,
- max_tokens: Optional[int] = None,
- temperature: Optional[int] = None,
- top_k: Optional[int] = None,
+ self,
+ max_tokens: Optional[int] = None,
+ temperature: Optional[int] = None,
+ top_k: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@@ -60,6 +57,7 @@ class ClarifaiConfig:
and v is not None
}
+
def validate_environment(api_key):
headers = {
"accept": "application/json",
@@ -69,42 +67,37 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}"
return headers
-def completions_to_model(payload):
- # if payload["n"] != 1:
- # raise HTTPException(
- # status_code=422,
- # detail="Only one generation is supported. Please set candidate_count to 1.",
- # )
- params = {}
- if temperature := payload.get("temperature"):
- params["temperature"] = temperature
- if max_tokens := payload.get("max_tokens"):
- params["max_tokens"] = max_tokens
- return {
- "inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
- "model": {"output_info": {"params": params}},
-}
-
+def completions_to_model(payload):
+ # if payload["n"] != 1:
+ # raise HTTPException(
+ # status_code=422,
+ # detail="Only one generation is supported. Please set candidate_count to 1.",
+ # )
+
+ params = {}
+ if temperature := payload.get("temperature"):
+ params["temperature"] = temperature
+ if max_tokens := payload.get("max_tokens"):
+ params["max_tokens"] = max_tokens
+ return {
+ "inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
+ "model": {"output_info": {"params": params}},
+ }
+
+
def process_response(
- model,
- prompt,
- response,
- model_response,
- api_key,
- data,
- encoding,
- logging_obj
- ):
+ model, prompt, response, model_response, api_key, data, encoding, logging_obj
+):
logging_obj.post_call(
- input=prompt,
- api_key=api_key,
- original_response=response.text,
- additional_args={"complete_input_dict": data},
- )
- ## RESPONSE OBJECT
+ input=prompt,
+ api_key=api_key,
+ original_response=response.text,
+ additional_args={"complete_input_dict": data},
+ )
+ ## RESPONSE OBJECT
try:
- completion_response = response.json()
+ completion_response = response.json()
except Exception:
raise ClarifaiError(
message=response.text, status_code=response.status_code, url=model
@@ -119,7 +112,7 @@ def process_response(
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason="stop",
- index=idx + 1, #check
+ index=idx + 1, # check
message=message_obj,
)
choices_list.append(choice_obj)
@@ -143,53 +136,56 @@ def process_response(
)
return model_response
+
def convert_model_to_url(model: str, api_base: str):
user_id, app_id, model_id = model.split(".")
return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"
+
def get_prompt_model_name(url: str):
clarifai_model_name = url.split("/")[-2]
if "claude" in clarifai_model_name:
return "anthropic", clarifai_model_name.replace("_", ".")
- if ("llama" in clarifai_model_name)or ("mistral" in clarifai_model_name):
+ if ("llama" in clarifai_model_name) or ("mistral" in clarifai_model_name):
return "", "meta-llama/llama-2-chat"
else:
return "", clarifai_model_name
+
async def async_completion(
- model: str,
- prompt: str,
- api_base: str,
- custom_prompt_dict: dict,
- model_response: ModelResponse,
- print_verbose: Callable,
- encoding,
- api_key,
- logging_obj,
- data=None,
- optional_params=None,
- litellm_params=None,
- logger_fn=None,
- headers={}):
-
- async_handler = AsyncHTTPHandler(
- timeout=httpx.Timeout(timeout=600.0, connect=5.0)
- )
+ model: str,
+ prompt: str,
+ api_base: str,
+ custom_prompt_dict: dict,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ data=None,
+ optional_params=None,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+):
+
+ async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = await async_handler.post(
- api_base, headers=headers, data=json.dumps(data)
- )
-
- return process_response(
- model=model,
- prompt=prompt,
- response=response,
- model_response=model_response,
- api_key=api_key,
- data=data,
- encoding=encoding,
- logging_obj=logging_obj,
+ api_base, headers=headers, data=json.dumps(data)
)
+ return process_response(
+ model=model,
+ prompt=prompt,
+ response=response,
+ model_response=model_response,
+ api_key=api_key,
+ data=data,
+ encoding=encoding,
+ logging_obj=logging_obj,
+ )
+
+
def completion(
model: str,
messages: list,
@@ -207,14 +203,12 @@ def completion(
):
headers = validate_environment(api_key)
model = convert_model_to_url(model, api_base)
- prompt = " ".join(message["content"] for message in messages) # TODO
+ prompt = " ".join(message["content"] for message in messages) # TODO
## Load Config
config = litellm.ClarifaiConfig.get_config()
for k, v in config.items():
- if (
- k not in optional_params
- ):
+ if k not in optional_params:
optional_params[k] = v
custom_llm_provider, orig_model_name = get_prompt_model_name(model)
@@ -223,14 +217,14 @@ def completion(
model=orig_model_name,
messages=messages,
api_key=api_key,
- custom_llm_provider="clarifai"
+ custom_llm_provider="clarifai",
)
else:
prompt = prompt_factory(
model=orig_model_name,
messages=messages,
api_key=api_key,
- custom_llm_provider=custom_llm_provider
+ custom_llm_provider=custom_llm_provider,
)
# print(prompt); exit(0)
@@ -240,7 +234,6 @@ def completion(
}
data = completions_to_model(data)
-
## LOGGING
logging_obj.pre_call(
input=prompt,
@@ -251,7 +244,7 @@ def completion(
"api_base": api_base,
},
)
- if acompletion==True:
+ if acompletion == True:
return async_completion(
model=model,
prompt=prompt,
@@ -271,15 +264,17 @@ def completion(
else:
## COMPLETION CALL
response = requests.post(
- model,
- headers=headers,
- data=json.dumps(data),
- )
+ model,
+ headers=headers,
+ data=json.dumps(data),
+ )
# print(response.content); exit()
if response.status_code != 200:
- raise ClarifaiError(status_code=response.status_code, message=response.text, url=model)
-
+ raise ClarifaiError(
+ status_code=response.status_code, message=response.text, url=model
+ )
+
if "stream" in optional_params and optional_params["stream"] == True:
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
@@ -287,11 +282,11 @@ def completion(
model=model,
custom_llm_provider="clarifai",
logging_obj=logging_obj,
- )
+ )
return stream_response
-
+
else:
- return process_response(
+ return process_response(
model=model,
prompt=prompt,
response=response,
@@ -299,8 +294,9 @@ def completion(
api_key=api_key,
data=data,
encoding=encoding,
- logging_obj=logging_obj)
-
+ logging_obj=logging_obj,
+ )
+
class ModelResponseIterator:
def __init__(self, model_response):
@@ -325,4 +321,4 @@ class ModelResponseIterator:
if self.is_done:
raise StopAsyncIteration
self.is_done = True
- return self.model_response
\ No newline at end of file
+ return self.model_response
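
For reference, the reformatted `completions_to_model` helper above converts an OpenAI-style payload into Clarifai's request body. A small standalone sketch of the same transformation (the function name here is illustrative, not an import from the module):

def completions_to_model_sketch(payload: dict) -> dict:
    # Copy only the sampling params Clarifai understands.
    params = {}
    if payload.get("temperature") is not None:
        params["temperature"] = payload["temperature"]
    if payload.get("max_tokens") is not None:
        params["max_tokens"] = payload["max_tokens"]
    return {
        "inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
        "model": {"output_info": {"params": params}},
    }

print(completions_to_model_sketch({"prompt": "Hi", "temperature": 0.7, "max_tokens": 64}))
# {'inputs': [{'data': {'text': {'raw': 'Hi'}}}],
#  'model': {'output_info': {'params': {'temperature': 0.7, 'max_tokens': 64}}}}
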
diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py
index 0ebdf38f1..14a66b54a 100644
--- a/litellm/llms/cohere.py
+++ b/litellm/llms/cohere.py
@@ -117,6 +117,7 @@ class CohereConfig:
def validate_environment(api_key):
headers = {
+ "Request-Source":"unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
diff --git a/litellm/llms/cohere_chat.py b/litellm/llms/cohere_chat.py
index e4de6ddcb..8ae839243 100644
--- a/litellm/llms/cohere_chat.py
+++ b/litellm/llms/cohere_chat.py
@@ -112,6 +112,7 @@ class CohereChatConfig:
def validate_environment(api_key):
headers = {
+ "Request-Source":"unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
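
Both Cohere handlers now tag outbound requests with a `Request-Source` header. Assuming an API key is supplied, the headers built by `validate_environment` would look roughly like this (the key value is a placeholder):

headers = {
    "Request-Source": "unspecified:litellm",
    "accept": "application/json",
    "content-type": "application/json",
    "Authorization": "Bearer <COHERE_API_KEY>",  # placeholder, not a real key
}
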
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 0adbd95bf..b91aaee2a 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -1,4 +1,5 @@
-import httpx, asyncio
+import litellm
+import httpx, asyncio, traceback, os
from typing import Optional, Union, Mapping, Any
# https://www.python-httpx.org/advanced/timeouts
@@ -7,8 +8,36 @@ _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
class AsyncHTTPHandler:
def __init__(
- self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
+ self,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ concurrent_limit=1000,
):
+ async_proxy_mounts = None
+ # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
+ http_proxy = os.getenv("HTTP_PROXY", None)
+ https_proxy = os.getenv("HTTPS_PROXY", None)
+ no_proxy = os.getenv("NO_PROXY", None)
+ ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
+ cert = os.getenv(
+ "SSL_CERTIFICATE", litellm.ssl_certificate
+ ) # /path/to/client.pem
+
+ if http_proxy is not None and https_proxy is not None:
+ async_proxy_mounts = {
+ "http://": httpx.AsyncHTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
+ "https://": httpx.AsyncHTTPTransport(
+ proxy=httpx.Proxy(url=https_proxy)
+ ),
+ }
+ # assume no_proxy is a comma-separated list of URLs
+ if no_proxy is not None and isinstance(no_proxy, str):
+ no_proxy_urls = no_proxy.split(",")
+
+ for url in no_proxy_urls: # set no-proxy support for specific urls
+ async_proxy_mounts[url] = None # type: ignore
+
+ if timeout is None:
+ timeout = _DEFAULT_TIMEOUT
# Create a client with a connection pool
self.client = httpx.AsyncClient(
timeout=timeout,
@@ -16,6 +45,9 @@ class AsyncHTTPHandler:
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
+ verify=ssl_verify,
+ mounts=async_proxy_mounts,
+ cert=cert,
)
async def close(self):
@@ -39,15 +71,22 @@ class AsyncHTTPHandler:
self,
url: str,
data: Optional[Union[dict, str]] = None, # type: ignore
+ json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
- req = self.client.build_request(
- "POST", url, data=data, params=params, headers=headers # type: ignore
- )
- response = await self.client.send(req, stream=stream)
- return response
+ try:
+ req = self.client.build_request(
+ "POST", url, data=data, json=json, params=params, headers=headers # type: ignore
+ )
+ response = await self.client.send(req, stream=stream)
+ response.raise_for_status()
+ return response
+ except httpx.HTTPStatusError as e:
+ raise e
+ except Exception as e:
+ raise e
def __del__(self) -> None:
try:
@@ -59,13 +98,35 @@ class AsyncHTTPHandler:
class HTTPHandler:
def __init__(
self,
- timeout: Optional[httpx.Timeout] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
concurrent_limit=1000,
client: Optional[httpx.Client] = None,
):
if timeout is None:
timeout = _DEFAULT_TIMEOUT
+ # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
+ http_proxy = os.getenv("HTTP_PROXY", None)
+ https_proxy = os.getenv("HTTPS_PROXY", None)
+ no_proxy = os.getenv("NO_PROXY", None)
+ ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
+ cert = os.getenv(
+ "SSL_CERTIFICATE", litellm.ssl_certificate
+ ) # /path/to/client.pem
+
+ sync_proxy_mounts = None
+ if http_proxy is not None and https_proxy is not None:
+ sync_proxy_mounts = {
+ "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
+ "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
+ }
+ # assume no_proxy is a comma-separated list of URLs
+ if no_proxy is not None and isinstance(no_proxy, str):
+ no_proxy_urls = no_proxy.split(",")
+
+ for url in no_proxy_urls: # set no-proxy support for specific urls
+ sync_proxy_mounts[url] = None # type: ignore
+
if client is None:
# Create a client with a connection pool
self.client = httpx.Client(
@@ -74,6 +135,9 @@ class HTTPHandler:
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
+ verify=ssl_verify,
+ mounts=sync_proxy_mounts,
+ cert=cert,
)
else:
self.client = client
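
The handlers above now read proxy and TLS settings from the environment when they are constructed. A minimal sketch of how the same mounts/verify arguments are assembled for an `httpx` client, assuming the example environment values shown (the handler classes do this internally, so this is for illustration only):

import os
import httpx

os.environ["HTTP_PROXY"] = "http://proxy.internal:3128"    # example value
os.environ["HTTPS_PROXY"] = "http://proxy.internal:3128"   # example value
os.environ["NO_PROXY"] = "http://localhost"                # comma-separated URLs to bypass

mounts = {
    "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=os.environ["HTTP_PROXY"])),
    "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=os.environ["HTTPS_PROXY"])),
}
for url in os.environ["NO_PROXY"].split(","):
    mounts[url] = None  # bypass the proxy for these URL prefixes

# verify/cert would come from SSL_VERIFY / SSL_CERTIFICATE in the real handlers
client = httpx.Client(mounts=mounts, verify=True)
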
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
new file mode 100644
index 000000000..4fe475259
--- /dev/null
+++ b/litellm/llms/databricks.py
@@ -0,0 +1,718 @@
+# What is this?
+## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+from functools import partial
+import os, types
+import json
+from enum import Enum
+import requests, copy # type: ignore
+import time
+from typing import Callable, Optional, List, Union, Tuple, Literal
+from litellm.utils import (
+ ModelResponse,
+ Usage,
+ map_finish_reason,
+ CustomStreamWrapper,
+ EmbeddingResponse,
+)
+import litellm
+from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from .base import BaseLLM
+import httpx # type: ignore
+from litellm.types.llms.databricks import GenericStreamingChunk
+from litellm.types.utils import ProviderField
+
+
+class DatabricksError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
+
+
+class DatabricksConfig:
+ """
+ Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+ """
+
+ max_tokens: Optional[int] = None
+ temperature: Optional[int] = None
+ top_p: Optional[int] = None
+ top_k: Optional[int] = None
+ stop: Optional[Union[List[str], str]] = None
+ n: Optional[int] = None
+
+ def __init__(
+ self,
+ max_tokens: Optional[int] = None,
+ temperature: Optional[int] = None,
+ top_p: Optional[int] = None,
+ top_k: Optional[int] = None,
+ stop: Optional[Union[List[str], str]] = None,
+ n: Optional[int] = None,
+ ) -> None:
+ locals_ = locals()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_required_params(self) -> List[ProviderField]:
+ """For a given provider, return it's required fields with a description"""
+ return [
+ ProviderField(
+ field_name="api_key",
+ field_type="string",
+ field_description="Your Databricks API Key.",
+ field_value="dapi...",
+ ),
+ ProviderField(
+ field_name="api_base",
+ field_type="string",
+ field_description="Your Databricks API Base.",
+ field_value="https://adb-..",
+ ),
+ ]
+
+ def get_supported_openai_params(self):
+ return ["stream", "stop", "temperature", "top_p", "max_tokens", "n"]
+
+ def map_openai_params(self, non_default_params: dict, optional_params: dict):
+ for param, value in non_default_params.items():
+ if param == "max_tokens":
+ optional_params["max_tokens"] = value
+ if param == "n":
+ optional_params["n"] = value
+ if param == "stream" and value == True:
+ optional_params["stream"] = value
+ if param == "temperature":
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["top_p"] = value
+ if param == "stop":
+ optional_params["stop"] = value
+ return optional_params
+
+ def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
+ try:
+ text = ""
+ is_finished = False
+ finish_reason = None
+ logprobs = None
+ usage = None
+ original_chunk = None # this is used for function/tool calling
+ chunk_data = chunk_data.replace("data:", "")
+ chunk_data = chunk_data.strip()
+ if len(chunk_data) == 0 or chunk_data == "[DONE]":
+ return {
+ "text": "",
+ "is_finished": is_finished,
+ "finish_reason": finish_reason,
+ }
+ chunk_data_dict = json.loads(chunk_data)
+ str_line = litellm.ModelResponse(**chunk_data_dict, stream=True)
+
+ if len(str_line.choices) > 0:
+ if (
+ str_line.choices[0].delta is not None # type: ignore
+ and str_line.choices[0].delta.content is not None # type: ignore
+ ):
+ text = str_line.choices[0].delta.content # type: ignore
+ else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
+ original_chunk = str_line
+ if str_line.choices[0].finish_reason:
+ is_finished = True
+ finish_reason = str_line.choices[0].finish_reason
+ if finish_reason == "content_filter":
+ if hasattr(str_line.choices[0], "content_filter_result"):
+ error_message = json.dumps(
+ str_line.choices[0].content_filter_result # type: ignore
+ )
+ else:
+ error_message = "Azure Response={}".format(
+ str(dict(str_line))
+ )
+ raise litellm.AzureOpenAIError(
+ status_code=400, message=error_message
+ )
+
+ # checking for logprobs
+ if (
+ hasattr(str_line.choices[0], "logprobs")
+ and str_line.choices[0].logprobs is not None
+ ):
+ logprobs = str_line.choices[0].logprobs
+ else:
+ logprobs = None
+
+ usage = getattr(str_line, "usage", None)
+
+ return GenericStreamingChunk(
+ text=text,
+ is_finished=is_finished,
+ finish_reason=finish_reason,
+ logprobs=logprobs,
+ original_chunk=original_chunk,
+ usage=usage,
+ )
+ except Exception as e:
+ raise e
+
+
+class DatabricksEmbeddingConfig:
+ """
+ Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
+ """
+
+ instruction: Optional[str] = (
+ None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
+ )
+
+ def __init__(self, instruction: Optional[str] = None) -> None:
+ locals_ = locals()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(
+ self,
+ ): # no optional openai embedding params supported
+ return []
+
+ def map_openai_params(self, non_default_params: dict, optional_params: dict):
+ return optional_params
+
+
+async def make_call(
+ client: AsyncHTTPHandler,
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200:
+ raise DatabricksError(status_code=response.status_code, message=response.text)
+
+ completion_stream = response.aiter_lines()
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=completion_stream, # Pass the completion stream for logging
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
+class DatabricksChatCompletion(BaseLLM):
+ def __init__(self) -> None:
+ super().__init__()
+
+ # makes headers for API call
+
+ def _validate_environment(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ endpoint_type: Literal["chat_completions", "embeddings"],
+ ) -> Tuple[str, dict]:
+ if api_key is None:
+ raise DatabricksError(
+ status_code=400,
+ message="Missing Databricks API Key - A call is being made to Databricks but no key is set either in the environment variables (DATABRICKS_API_KEY) or via params",
+ )
+
+ if api_base is None:
+ raise DatabricksError(
+ status_code=400,
+ message="Missing Databricks API Base - A call is being made to Databricks but no api base is set either in the environment variables (DATABRICKS_API_BASE) or via params",
+ )
+
+ headers = {
+ "Authorization": "Bearer {}".format(api_key),
+ "Content-Type": "application/json",
+ }
+
+ if endpoint_type == "chat_completions":
+ api_base = "{}/chat/completions".format(api_base)
+ elif endpoint_type == "embeddings":
+ api_base = "{}/embeddings".format(api_base)
+ return api_base, headers
+
+ def process_response(
+ self,
+ model: str,
+ response: Union[requests.Response, httpx.Response],
+ model_response: ModelResponse,
+ stream: bool,
+ logging_obj: litellm.utils.Logging,
+ optional_params: dict,
+ api_key: str,
+ data: Union[dict, str],
+ messages: List,
+ print_verbose,
+ encoding,
+ ) -> ModelResponse:
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=response.text,
+ additional_args={"complete_input_dict": data},
+ )
+ print_verbose(f"raw model_response: {response.text}")
+ ## RESPONSE OBJECT
+ try:
+ completion_response = response.json()
+ except:
+ raise DatabricksError(
+ message=response.text, status_code=response.status_code
+ )
+ if "error" in completion_response:
+ raise DatabricksError(
+ message=str(completion_response["error"]),
+ status_code=response.status_code,
+ )
+ else:
+ text_content = ""
+ tool_calls = []
+ for content in completion_response["content"]:
+ if content["type"] == "text":
+ text_content += content["text"]
+ ## TOOL CALLING
+ elif content["type"] == "tool_use":
+ tool_calls.append(
+ {
+ "id": content["id"],
+ "type": "function",
+ "function": {
+ "name": content["name"],
+ "arguments": json.dumps(content["input"]),
+ },
+ }
+ )
+
+ _message = litellm.Message(
+ tool_calls=tool_calls,
+ content=text_content or None,
+ )
+ model_response.choices[0].message = _message # type: ignore
+ model_response._hidden_params["original_response"] = completion_response[
+ "content"
+ ] # allow user to access raw anthropic tool calling response
+
+ model_response.choices[0].finish_reason = map_finish_reason(
+ completion_response["stop_reason"]
+ )
+
+ ## CALCULATING USAGE
+ prompt_tokens = completion_response["usage"]["input_tokens"]
+ completion_tokens = completion_response["usage"]["output_tokens"]
+ total_tokens = prompt_tokens + completion_tokens
+
+ model_response["created"] = int(time.time())
+ model_response["model"] = model
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ )
+ setattr(model_response, "usage", usage) # type: ignore
+ return model_response
+
+ async def acompletion_stream_function(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_prompt_dict: dict,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ stream,
+ data: dict,
+ optional_params=None,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+ client: Optional[AsyncHTTPHandler] = None,
+ ) -> CustomStreamWrapper:
+
+ data["stream"] = True
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ api_base=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
+ model=model,
+ custom_llm_provider="databricks",
+ logging_obj=logging_obj,
+ )
+ return streamwrapper
+
+ async def acompletion_function(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_prompt_dict: dict,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ stream,
+ data: dict,
+ optional_params: dict,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ ) -> ModelResponse:
+ if timeout is None:
+ timeout = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ self.async_handler = AsyncHTTPHandler(timeout=timeout)
+
+ try:
+ response = await self.async_handler.post(
+ api_base, headers=headers, data=json.dumps(data)
+ )
+ response.raise_for_status()
+
+ response_json = response.json()
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code,
+ message=response.text if response else str(e),
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(status_code=408, message="Timeout error occurred.")
+ except Exception as e:
+ raise DatabricksError(status_code=500, message=str(e))
+
+ return ModelResponse(**response_json)
+
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_prompt_dict: dict,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ optional_params: dict,
+ acompletion=None,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ ):
+ api_base, headers = self._validate_environment(
+ api_base=api_base, api_key=api_key, endpoint_type="chat_completions"
+ )
+ ## Load Config
+ config = litellm.DatabricksConfig().get_config()
+ for k, v in config.items():
+ if (
+ k not in optional_params
+ ): # completion(top_k=3) > databricks_config(top_k=3) <- allows for dynamic variables to be passed in
+ optional_params[k] = v
+
+ stream = optional_params.pop("stream", None)
+
+ data = {
+ "model": model,
+ "messages": messages,
+ **optional_params,
+ }
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+ if acompletion == True:
+ if client is not None and isinstance(client, HTTPHandler):
+ client = None
+ if (
+ stream is not None and stream == True
+ ): # streaming enabled - return an async CustomStreamWrapper
+ print_verbose("makes async databricks streaming POST request")
+ data["stream"] = stream
+ return self.acompletion_stream_function(
+ model=model,
+ messages=messages,
+ data=data,
+ api_base=api_base,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ api_key=api_key,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=stream,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ headers=headers,
+ client=client,
+ )
+ else:
+ return self.acompletion_function(
+ model=model,
+ messages=messages,
+ data=data,
+ api_base=api_base,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ api_key=api_key,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=stream,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ headers=headers,
+ timeout=timeout,
+ )
+ else:
+ if client is None or isinstance(client, AsyncHTTPHandler):
+ self.client = HTTPHandler(timeout=timeout) # type: ignore
+ else:
+ self.client = client
+ ## COMPLETION CALL
+ if (
+ stream is not None and stream == True
+ ): # streaming enabled - make a streaming POST request
+ print_verbose("makes dbrx streaming POST request")
+ data["stream"] = stream
+ try:
+ response = self.client.post(
+ api_base, headers=headers, data=json.dumps(data), stream=stream
+ )
+ response.raise_for_status()
+ completion_stream = response.iter_lines()
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code, message=response.text
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(
+ status_code=408, message="Timeout error occurred."
+ )
+ except Exception as e:
+ raise DatabricksError(status_code=408, message=str(e))
+
+ streaming_response = CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider="databricks",
+ logging_obj=logging_obj,
+ )
+ return streaming_response
+
+ else:
+ try:
+ response = self.client.post(
+ api_base, headers=headers, data=json.dumps(data)
+ )
+ response.raise_for_status()
+
+ response_json = response.json()
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code, message=response.text
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(
+ status_code=408, message="Timeout error occurred."
+ )
+ except Exception as e:
+ raise DatabricksError(status_code=500, message=str(e))
+
+ return ModelResponse(**response_json)
+
+ async def aembedding(
+ self,
+ input: list,
+ data: dict,
+ model_response: ModelResponse,
+ timeout: float,
+ api_key: str,
+ api_base: str,
+ logging_obj,
+ headers: dict,
+ client=None,
+ ) -> EmbeddingResponse:
+ response = None
+ try:
+ if client is None or isinstance(client, AsyncHTTPHandler):
+ self.async_client = AsyncHTTPHandler(timeout=timeout) # type: ignore
+ else:
+ self.async_client = client
+
+ try:
+ response = await self.async_client.post(
+ api_base,
+ headers=headers,
+ data=json.dumps(data),
+ ) # type: ignore
+
+ response.raise_for_status()
+
+ response_json = response.json()
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code,
+ message=response.text if response else str(e),
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(
+ status_code=408, message="Timeout error occurred."
+ )
+ except Exception as e:
+ raise DatabricksError(status_code=500, message=str(e))
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=response_json,
+ )
+ return EmbeddingResponse(**response_json)
+ except Exception as e:
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ original_response=str(e),
+ )
+ raise e
+
+ def embedding(
+ self,
+ model: str,
+ input: list,
+ timeout: float,
+ logging_obj,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ optional_params: dict,
+ model_response: Optional[litellm.utils.EmbeddingResponse] = None,
+ client=None,
+ aembedding=None,
+ ) -> EmbeddingResponse:
+ api_base, headers = self._validate_environment(
+ api_base=api_base, api_key=api_key, endpoint_type="embeddings"
+ )
+ model = model
+ data = {"model": model, "input": input, **optional_params}
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data, "api_base": api_base},
+ )
+
+ if aembedding == True:
+ return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers) # type: ignore
+ if client is None or isinstance(client, AsyncHTTPHandler):
+ self.client = HTTPHandler(timeout=timeout) # type: ignore
+ else:
+ self.client = client
+
+ ## EMBEDDING CALL
+ try:
+ response = self.client.post(
+ api_base,
+ headers=headers,
+ data=json.dumps(data),
+ ) # type: ignore
+
+ response.raise_for_status() # type: ignore
+
+ response_json = response.json() # type: ignore
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code,
+ message=response.text if response else str(e),
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(status_code=408, message="Timeout error occurred.")
+ except Exception as e:
+ raise DatabricksError(status_code=500, message=str(e))
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=response_json,
+ )
+
+ return litellm.EmbeddingResponse(**response_json)
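
The new Databricks handler follows the OpenAI-compatible chat contract, with `DatabricksConfig` deciding which OpenAI parameters pass through. A quick hedged sketch of the parameter mapping, assuming the config class is exported at the package root as referenced in the handler above:

import litellm

cfg = litellm.DatabricksConfig()
print(cfg.get_supported_openai_params())
# ['stream', 'stop', 'temperature', 'top_p', 'max_tokens', 'n']

optional_params = cfg.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.2, "stream": True},
    optional_params={},
)
print(optional_params)  # {'max_tokens': 256, 'temperature': 0.2, 'stream': True}
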
diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index 9c9b5e898..283878056 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -45,6 +45,8 @@ class OllamaConfig:
- `temperature` (float): The temperature of the model. Increasing the temperature will make the model answer more creatively. Default: 0.8. Example usage: temperature 0.7
+ - `seed` (int): Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. Example usage: seed 42
+
- `stop` (string[]): Sets the stop sequences to use. Example usage: stop "AI assistant:"
- `tfs_z` (float): Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. Default: 1. Example usage: tfs_z 1
@@ -69,6 +71,7 @@ class OllamaConfig:
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
+ seed: Optional[int] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
@@ -90,6 +93,7 @@ class OllamaConfig:
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
+ seed: Optional[int] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
@@ -120,6 +124,44 @@ class OllamaConfig:
)
and v is not None
}
+ def get_supported_openai_params(
+ self,
+ ):
+ return [
+ "max_tokens",
+ "stream",
+ "top_p",
+ "temperature",
+ "seed",
+ "frequency_penalty",
+ "stop",
+ "response_format",
+ ]
+
+# Ollama wants plain base64-encoded JPEG/PNG images. Strip any leading data URI
+# and convert to JPEG if necessary.
+def _convert_image(image):
+ import base64, io
+ try:
+ from PIL import Image
+ except:
+ raise Exception(
+ "ollama image conversion failed please run `pip install Pillow`"
+ )
+
+ orig = image
+ if image.startswith("data:"):
+ image = image.split(",")[-1]
+ try:
+ image_data = Image.open(io.BytesIO(base64.b64decode(image)))
+ if image_data.format in ["JPEG", "PNG"]:
+ return image
+ except:
+ return orig
+ jpeg_image = io.BytesIO()
+ image_data.convert("RGB").save(jpeg_image, "JPEG")
+ jpeg_image.seek(0)
+ return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")
# ollama implementation
@@ -158,7 +200,7 @@ def get_ollama_response(
if format is not None:
data["format"] = format
if images is not None:
- data["images"] = images
+ data["images"] = [_convert_image(image) for image in images]
## LOGGING
logging_obj.pre_call(
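
The new `_convert_image` helper strips any `data:` URI prefix and re-encodes non-JPEG/PNG images as base64 JPEG before they are sent to Ollama. A rough standalone sketch of the same idea (requires Pillow; the function name here is illustrative):

import base64, io
from PIL import Image

def to_plain_base64_jpeg(image: str) -> str:
    if image.startswith("data:"):          # drop any data-URI header
        image = image.split(",")[-1]
    data = Image.open(io.BytesIO(base64.b64decode(image)))
    if data.format in ("JPEG", "PNG"):     # already a format Ollama accepts
        return image
    buf = io.BytesIO()
    data.convert("RGB").save(buf, "JPEG")  # e.g. WEBP/GIF -> JPEG
    return base64.b64encode(buf.getvalue()).decode("utf-8")
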
diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index d1ff4953f..a05807722 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -45,6 +45,8 @@ class OllamaChatConfig:
- `temperature` (float): The temperature of the model. Increasing the temperature will make the model answer more creatively. Default: 0.8. Example usage: temperature 0.7
+ - `seed` (int): Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. Example usage: seed 42
+
- `stop` (string[]): Sets the stop sequences to use. Example usage: stop "AI assistant:"
- `tfs_z` (float): Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. Default: 1. Example usage: tfs_z 1
@@ -69,6 +71,7 @@ class OllamaChatConfig:
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
+ seed: Optional[int] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
@@ -90,6 +93,7 @@ class OllamaChatConfig:
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
+ seed: Optional[int] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
@@ -130,6 +134,7 @@ class OllamaChatConfig:
"stream",
"top_p",
"temperature",
+ "seed",
"frequency_penalty",
"stop",
"tools",
@@ -146,6 +151,8 @@ class OllamaChatConfig:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
+ if param == "seed":
+ optional_params["seed"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "frequency_penalty":
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 7acbdfae0..dec86d35d 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -6,7 +6,8 @@ from typing import (
Literal,
Iterable,
)
-from typing_extensions import override
+import hashlib
+from typing_extensions import override, overload
from pydantic import BaseModel
import types, time, json, traceback
import httpx
@@ -21,11 +22,12 @@ from litellm.utils import (
TranscriptionResponse,
TextCompletionResponse,
)
-from typing import Callable, Optional
+from typing import Callable, Optional, Coroutine
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
from ..types.llms.openai import *
+import openai
class OpenAIError(Exception):
@@ -96,7 +98,7 @@ class MistralConfig:
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
) -> None:
- locals_ = locals()
+ locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@@ -157,6 +159,102 @@ class MistralConfig:
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
+ if param == "response_format":
+ optional_params["response_format"] = value
+ return optional_params
+
+
+class DeepInfraConfig:
+ """
+ Reference: https://deepinfra.com/docs/advanced/openai_api
+
+ The class `DeepInfraConfig` provides configuration for DeepInfra's Chat Completions API interface. Below are the parameters:
+ """
+
+ frequency_penalty: Optional[int] = None
+ function_call: Optional[Union[str, dict]] = None
+ functions: Optional[list] = None
+ logit_bias: Optional[dict] = None
+ max_tokens: Optional[int] = None
+ n: Optional[int] = None
+ presence_penalty: Optional[int] = None
+ stop: Optional[Union[str, list]] = None
+ temperature: Optional[int] = None
+ top_p: Optional[int] = None
+ response_format: Optional[dict] = None
+ tools: Optional[list] = None
+ tool_choice: Optional[Union[str, dict]] = None
+
+ def __init__(
+ self,
+ frequency_penalty: Optional[int] = None,
+ function_call: Optional[Union[str, dict]] = None,
+ functions: Optional[list] = None,
+ logit_bias: Optional[dict] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[int] = None,
+ stop: Optional[Union[str, list]] = None,
+ temperature: Optional[int] = None,
+ top_p: Optional[int] = None,
+ response_format: Optional[dict] = None,
+ tools: Optional[list] = None,
+ tool_choice: Optional[Union[str, dict]] = None,
+ ) -> None:
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(self):
+ return [
+ "stream",
+ "frequency_penalty",
+ "function_call",
+ "functions",
+ "logit_bias",
+ "max_tokens",
+ "n",
+ "presence_penalty",
+ "stop",
+ "temperature",
+ "top_p",
+ "response_format",
+ "tools",
+ "tool_choice",
+ ]
+
+ def map_openai_params(
+ self, non_default_params: dict, optional_params: dict, model: str
+ ):
+ supported_openai_params = self.get_supported_openai_params()
+ for param, value in non_default_params.items():
+ if (
+ param == "temperature"
+ and value == 0
+ and model == "mistralai/Mistral-7B-Instruct-v0.1"
+ ): # this model does not support temperature == 0
+ value = 0.0001 # close to 0
+ if param in supported_openai_params:
+ optional_params[param] = value
return optional_params
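
One behavioral detail in `DeepInfraConfig.map_openai_params` worth illustrating: a temperature of exactly 0 is nudged to 0.0001 for `mistralai/Mistral-7B-Instruct-v0.1`, which rejects 0. A brief sketch, assuming the class is exported at the package root like the other provider configs:

import litellm

params = litellm.DeepInfraConfig().map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 100},
    optional_params={},
    model="mistralai/Mistral-7B-Instruct-v0.1",
)
print(params)  # {'temperature': 0.0001, 'max_tokens': 100}
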
@@ -197,6 +295,7 @@ class OpenAIConfig:
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
+ response_format: Optional[dict] = None
def __init__(
self,
@@ -210,8 +309,9 @@ class OpenAIConfig:
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
+ response_format: Optional[dict] = None,
) -> None:
- locals_ = locals()
+ locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@@ -234,6 +334,52 @@ class OpenAIConfig:
and v is not None
}
+ def get_supported_openai_params(self, model: str) -> list:
+ base_params = [
+ "frequency_penalty",
+ "logit_bias",
+ "logprobs",
+ "top_logprobs",
+ "max_tokens",
+ "n",
+ "presence_penalty",
+ "seed",
+ "stop",
+ "stream",
+ "stream_options",
+ "temperature",
+ "top_p",
+ "tools",
+ "tool_choice",
+ "function_call",
+ "functions",
+ "max_retries",
+ "extra_headers",
+ ] # works across all models
+
+ model_specific_params = []
+ if (
+ model != "gpt-3.5-turbo-16k" and model != "gpt-4"
+ ): # gpt-3.5-turbo-16k and gpt-4 do not support 'response_format'
+ model_specific_params.append("response_format")
+
+ if (
+ model in litellm.open_ai_chat_completion_models
+ ) or model in litellm.open_ai_text_completion_models:
+ model_specific_params.append(
+ "user"
+ ) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
+ return base_params + model_specific_params
+
+ def map_openai_params(
+ self, non_default_params: dict, optional_params: dict, model: str
+ ) -> dict:
+ supported_openai_params = self.get_supported_openai_params(model)
+ for param, value in non_default_params.items():
+ if param in supported_openai_params:
+ optional_params[param] = value
+ return optional_params
+
class OpenAITextCompletionConfig:
"""
@@ -294,7 +440,7 @@ class OpenAITextCompletionConfig:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> None:
- locals_ = locals()
+ locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@@ -359,10 +505,69 @@ class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
+ def _get_openai_client(
+ self,
+ is_async: bool,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+ max_retries: Optional[int] = None,
+ organization: Optional[str] = None,
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ):
+ args = locals()
+ if client is None:
+ if not isinstance(max_retries, int):
+ raise OpenAIError(
+ status_code=422,
+ message="max retries must be an int. Passed in value: {}".format(
+ max_retries
+ ),
+ )
+ # Creating a new OpenAI Client
+ # check in memory cache before creating a new one
+ # Convert the API key to bytes
+ hashed_api_key = None
+ if api_key is not None:
+ hash_object = hashlib.sha256(api_key.encode())
+ # Hexadecimal representation of the hash
+ hashed_api_key = hash_object.hexdigest()
+
+ _cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"
+
+ if _cache_key in litellm.in_memory_llm_clients_cache:
+ return litellm.in_memory_llm_clients_cache[_cache_key]
+ if is_async:
+ _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ http_client=litellm.aclient_session,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ )
+ else:
+ _new_client = OpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ http_client=litellm.client_session,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ )
+
+ ## SAVE CACHE KEY
+ litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
+ return _new_client
+
+ else:
+ return client
+
def completion(
self,
model_response: ModelResponse,
timeout: Union[float, httpx.Timeout],
+ optional_params: dict,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable] = None,
@@ -370,7 +575,6 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str] = None,
acompletion: bool = False,
logging_obj=None,
- optional_params=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
@@ -465,17 +669,16 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(
status_code=422, message="max retries must be an int"
)
- if client is None:
- openai_client = OpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.client_session,
- timeout=timeout,
- max_retries=max_retries,
- organization=organization,
- )
- else:
- openai_client = client
+
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
## LOGGING
logging_obj.pre_call(
@@ -555,17 +758,15 @@ class OpenAIChatCompletion(BaseLLM):
):
response = None
try:
- if client is None:
- openai_aclient = AsyncOpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.aclient_session,
- timeout=timeout,
- max_retries=max_retries,
- organization=organization,
- )
- else:
- openai_aclient = client
+ openai_aclient = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
## LOGGING
logging_obj.pre_call(
@@ -609,17 +810,15 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None,
headers=None,
):
- if client is None:
- openai_client = OpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.client_session,
- timeout=timeout,
- max_retries=max_retries,
- organization=organization,
- )
- else:
- openai_client = client
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
## LOGGING
logging_obj.pre_call(
input=data["messages"],
@@ -656,17 +855,15 @@ class OpenAIChatCompletion(BaseLLM):
):
response = None
try:
- if client is None:
- openai_aclient = AsyncOpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.aclient_session,
- timeout=timeout,
- max_retries=max_retries,
- organization=organization,
- )
- else:
- openai_aclient = client
+ openai_aclient = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
## LOGGING
logging_obj.pre_call(
input=data["messages"],
@@ -720,16 +917,14 @@ class OpenAIChatCompletion(BaseLLM):
):
response = None
try:
- if client is None:
- openai_aclient = AsyncOpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.aclient_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_aclient = client
+ openai_aclient = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump()
## LOGGING
@@ -754,10 +949,10 @@ class OpenAIChatCompletion(BaseLLM):
model: str,
input: list,
timeout: float,
+ logging_obj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
- logging_obj=None,
optional_params=None,
client=None,
aembedding=None,
@@ -777,19 +972,18 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data, "api_base": api_base},
)
- if aembedding == True:
+ if aembedding is True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
- if client is None:
- openai_client = OpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.client_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_client = client
+
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
## COMPLETION CALL
response = openai_client.embeddings.create(**data, timeout=timeout) # type: ignore
@@ -825,16 +1019,16 @@ class OpenAIChatCompletion(BaseLLM):
):
response = None
try:
- if client is None:
- openai_aclient = AsyncOpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.aclient_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_aclient = client
+
+ openai_aclient = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump()
## LOGGING
@@ -879,16 +1073,14 @@ class OpenAIChatCompletion(BaseLLM):
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
- if client is None:
- openai_client = OpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.client_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_client = client
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
## LOGGING
logging_obj.pre_call(
@@ -946,14 +1138,14 @@ class OpenAIChatCompletion(BaseLLM):
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
- api_key: Optional[str] = None,
- api_base: Optional[str] = None,
+ api_key: Optional[str],
+ api_base: Optional[str],
client=None,
logging_obj=None,
atranscription: bool = False,
):
data = {"model": model, "file": audio_file, **optional_params}
- if atranscription == True:
+ if atranscription is True:
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
@@ -965,16 +1157,14 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=max_retries,
logging_obj=logging_obj,
)
- if client is None:
- openai_client = OpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.client_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_client = client
+
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ )
response = openai_client.audio.transcriptions.create(
**data, timeout=timeout # type: ignore
)
@@ -1003,18 +1193,16 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None,
logging_obj=None,
):
- response = None
try:
- if client is None:
- openai_aclient = AsyncOpenAI(
- api_key=api_key,
- base_url=api_base,
- http_client=litellm.aclient_session,
- timeout=timeout,
- max_retries=max_retries,
- )
- else:
- openai_aclient = client
+ openai_aclient = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
response = await openai_aclient.audio.transcriptions.create(
**data, timeout=timeout
) # type: ignore
@@ -1037,6 +1225,87 @@ class OpenAIChatCompletion(BaseLLM):
)
raise e
+ def audio_speech(
+ self,
+ model: str,
+ input: str,
+ voice: str,
+ optional_params: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ organization: Optional[str],
+ project: Optional[str],
+ max_retries: int,
+ timeout: Union[float, httpx.Timeout],
+ aspeech: Optional[bool] = None,
+ client=None,
+ ) -> HttpxBinaryResponseContent:
+
+ if aspeech is not None and aspeech is True:
+ return self.async_audio_speech(
+ model=model,
+ input=input,
+ voice=voice,
+ optional_params=optional_params,
+ api_key=api_key,
+ api_base=api_base,
+ organization=organization,
+ project=project,
+ max_retries=max_retries,
+ timeout=timeout,
+ client=client,
+ ) # type: ignore
+
+ openai_client = self._get_openai_client(
+ is_async=False,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = openai_client.audio.speech.create(
+ model=model,
+ voice=voice, # type: ignore
+ input=input,
+ **optional_params,
+ )
+ return response
+
+ async def async_audio_speech(
+ self,
+ model: str,
+ input: str,
+ voice: str,
+ optional_params: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ organization: Optional[str],
+ project: Optional[str],
+ max_retries: int,
+ timeout: Union[float, httpx.Timeout],
+ client=None,
+ ) -> HttpxBinaryResponseContent:
+
+ openai_client = self._get_openai_client(
+ is_async=True,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ )
+
+ response = await openai_client.audio.speech.create(
+ model=model,
+ voice=voice, # type: ignore
+ input=input,
+ **optional_params,
+ )
+
+ return response
+
async def ahealth_check(
self,
model: Optional[str],
@@ -1358,6 +1627,322 @@ class OpenAITextCompletion(BaseLLM):
yield transformed_chunk
+class OpenAIFilesAPI(BaseLLM):
+ """
+ OpenAI methods to support file management
+ - create_file()
+ - retrieve_file()
+ - list_files()
+ - delete_file()
+ - file_content()
+ - update_file()
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def get_openai_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ _is_async: bool = False,
+ ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
+ received_args = locals()
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client" or k == "_is_async":
+ pass
+ elif k == "api_base" and v is not None:
+ data["base_url"] = v
+ elif v is not None:
+ data[k] = v
+ if _is_async is True:
+ openai_client = AsyncOpenAI(**data)
+ else:
+ openai_client = OpenAI(**data) # type: ignore
+ else:
+ openai_client = client
+
+ return openai_client
+
+ async def acreate_file(
+ self,
+ create_file_data: CreateFileRequest,
+ openai_client: AsyncOpenAI,
+ ) -> FileObject:
+ response = await openai_client.files.create(**create_file_data)
+ return response
+
+ def create_file(
+ self,
+ _is_async: bool,
+ create_file_data: CreateFileRequest,
+ api_base: str,
+ api_key: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.acreate_file( # type: ignore
+ create_file_data=create_file_data, openai_client=openai_client
+ )
+ response = openai_client.files.create(**create_file_data)
+ return response
+
+ async def afile_content(
+ self,
+ file_content_request: FileContentRequest,
+ openai_client: AsyncOpenAI,
+ ) -> HttpxBinaryResponseContent:
+ response = await openai_client.files.content(**file_content_request)
+ return response
+
+ def file_content(
+ self,
+ _is_async: bool,
+ file_content_request: FileContentRequest,
+ api_base: str,
+ api_key: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ) -> Union[
+ HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
+ ]:
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.afile_content( # type: ignore
+ file_content_request=file_content_request,
+ openai_client=openai_client,
+ )
+ response = openai_client.files.content(**file_content_request)
+
+ return response
+
+
+class OpenAIBatchesAPI(BaseLLM):
+ """
+ OpenAI methods to support batches
+ - create_batch()
+ - retrieve_batch()
+ - cancel_batch()
+ - list_batch()
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def get_openai_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ _is_async: bool = False,
+ ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
+ received_args = locals()
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client" or k == "_is_async":
+ pass
+ elif k == "api_base" and v is not None:
+ data["base_url"] = v
+ elif v is not None:
+ data[k] = v
+ if _is_async is True:
+ openai_client = AsyncOpenAI(**data)
+ else:
+ openai_client = OpenAI(**data) # type: ignore
+ else:
+ openai_client = client
+
+ return openai_client
+
+ async def acreate_batch(
+ self,
+ create_batch_data: CreateBatchRequest,
+ openai_client: AsyncOpenAI,
+ ) -> Batch:
+ response = await openai_client.batches.create(**create_batch_data)
+ return response
+
+ def create_batch(
+ self,
+ _is_async: bool,
+ create_batch_data: CreateBatchRequest,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.acreate_batch( # type: ignore
+ create_batch_data=create_batch_data, openai_client=openai_client
+ )
+ response = openai_client.batches.create(**create_batch_data)
+ return response
+
+ async def aretrieve_batch(
+ self,
+ retrieve_batch_data: RetrieveBatchRequest,
+ openai_client: AsyncOpenAI,
+ ) -> Batch:
+ response = await openai_client.batches.retrieve(**retrieve_batch_data)
+ return response
+
+ def retrieve_batch(
+ self,
+ _is_async: bool,
+ retrieve_batch_data: RetrieveBatchRequest,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI] = None,
+ ):
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.aretrieve_batch( # type: ignore
+ retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
+ )
+ response = openai_client.batches.retrieve(**retrieve_batch_data)
+ return response
+
+ def cancel_batch(
+ self,
+ _is_async: bool,
+ cancel_batch_data: CancelBatchRequest,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI] = None,
+ ):
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+ response = openai_client.batches.cancel(**cancel_batch_data)
+ return response
+
+ # def list_batch(
+ # self,
+ # list_batch_data: ListBatchRequest,
+ # api_key: Optional[str],
+ # api_base: Optional[str],
+ # timeout: Union[float, httpx.Timeout],
+ # max_retries: Optional[int],
+ # organization: Optional[str],
+ # client: Optional[OpenAI] = None,
+ # ):
+ # openai_client: OpenAI = self.get_openai_client(
+ # api_key=api_key,
+ # api_base=api_base,
+ # timeout=timeout,
+ # max_retries=max_retries,
+ # organization=organization,
+ # client=client,
+ # )
+ # response = openai_client.batches.list(**list_batch_data)
+ # return response
+
+
class OpenAIAssistantsAPI(BaseLLM):
def __init__(self) -> None:
super().__init__()
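The OpenAIFilesAPI and OpenAIBatchesAPI classes added above share one dispatch pattern: get_openai_client() builds either an OpenAI or AsyncOpenAI client, and when `_is_async=True` the sync entrypoint simply returns the coroutine from its async counterpart. Below is a minimal sketch of uploading a batch input file and creating a batch against it; the file path, API key, and request fields are placeholders that follow the OpenAI SDK's files.create / batches.create parameters, not values taken from this patch:

```python
import asyncio
from litellm.llms.openai import OpenAIFilesAPI, OpenAIBatchesAPI

files_api = OpenAIFilesAPI()
batches_api = OpenAIBatchesAPI()

async def run_batch_example():
    # Upload the JSONL input file (path and key are illustrative).
    file_obj = await files_api.create_file(
        _is_async=True,
        create_file_data={"file": open("requests.jsonl", "rb"), "purpose": "batch"},
        api_base="https://api.openai.com/v1",
        api_key="sk-...",
        timeout=600.0,
        max_retries=2,
        organization=None,
    )
    # Create a batch that references the uploaded file.
    return await batches_api.create_batch(
        _is_async=True,
        create_batch_data={
            "input_file_id": file_obj.id,
            "endpoint": "/v1/chat/completions",
            "completion_window": "24h",
        },
        api_key="sk-...",
        api_base="https://api.openai.com/v1",
        timeout=600.0,
        max_retries=2,
        organization=None,
    )

# asyncio.run(run_batch_example())
```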
@@ -1387,8 +1972,85 @@ class OpenAIAssistantsAPI(BaseLLM):
return openai_client
+ def async_get_openai_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI] = None,
+ ) -> AsyncOpenAI:
+ received_args = locals()
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client":
+ pass
+ elif k == "api_base" and v is not None:
+ data["base_url"] = v
+ elif v is not None:
+ data[k] = v
+ openai_client = AsyncOpenAI(**data) # type: ignore
+ else:
+ openai_client = client
+
+ return openai_client
+
### ASSISTANTS ###
+ async def async_get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ ) -> AsyncCursorPage[Assistant]:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+
+ response = await openai_client.beta.assistants.list()
+
+ return response
+
+ # fmt: off
+
+ @overload
+ def get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ aget_assistants: Literal[True],
+ ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
+ ...
+
+ @overload
+ def get_assistants(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI],
+ aget_assistants: Optional[Literal[False]],
+ ) -> SyncCursorPage[Assistant]:
+ ...
+
+ # fmt: on
+
def get_assistants(
self,
api_key: Optional[str],
@@ -1396,8 +2058,18 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI],
- ) -> SyncCursorPage[Assistant]:
+ client=None,
+ aget_assistants=None,
+ ):
+ if aget_assistants is not None and aget_assistants == True:
+ return self.async_get_assistants(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1413,18 +2085,95 @@ class OpenAIAssistantsAPI(BaseLLM):
### MESSAGES ###
- def add_message(
+ async def a_add_message(
self,
thread_id: str,
- message_data: MessageData,
+ message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI] = None,
+ client: Optional[AsyncOpenAI] = None,
) -> OpenAIMessage:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+ thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
+ thread_id, **message_data # type: ignore
+ )
+
+ response_obj: Optional[OpenAIMessage] = None
+ if getattr(thread_message, "status", None) is None:
+ thread_message.status = "completed"
+ response_obj = OpenAIMessage(**thread_message.dict())
+ else:
+ response_obj = OpenAIMessage(**thread_message.dict())
+ return response_obj
+
+ # fmt: off
+
+ @overload
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ a_add_message: Literal[True],
+ ) -> Coroutine[None, None, OpenAIMessage]:
+ ...
+
+ @overload
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI],
+ a_add_message: Optional[Literal[False]],
+ ) -> OpenAIMessage:
+ ...
+
+ # fmt: on
+
+ def add_message(
+ self,
+ thread_id: str,
+ message_data: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client=None,
+ a_add_message: Optional[bool] = None,
+ ):
+ if a_add_message is not None and a_add_message == True:
+ return self.a_add_message(
+ thread_id=thread_id,
+ message_data=message_data,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1446,6 +2195,61 @@ class OpenAIAssistantsAPI(BaseLLM):
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
+ async def async_get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI] = None,
+ ) -> AsyncCursorPage[OpenAIMessage]:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+ return response
+
+ # fmt: off
+
+ @overload
+ def get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ aget_messages: Literal[True],
+ ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
+ ...
+
+ @overload
+ def get_messages(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI],
+ aget_messages: Optional[Literal[False]],
+ ) -> SyncCursorPage[OpenAIMessage]:
+ ...
+
+ # fmt: on
+
def get_messages(
self,
thread_id: str,
@@ -1454,8 +2258,19 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI] = None,
- ) -> SyncCursorPage[OpenAIMessage]:
+ client=None,
+ aget_messages=None,
+ ):
+ if aget_messages is not None and aget_messages == True:
+ return self.async_get_messages(
+ thread_id=thread_id,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1471,6 +2286,70 @@ class OpenAIAssistantsAPI(BaseLLM):
### THREADS ###
+ async def async_create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ ) -> Thread:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+
+ data = {}
+ if messages is not None:
+ data["messages"] = messages # type: ignore
+ if metadata is not None:
+ data["metadata"] = metadata # type: ignore
+
+ message_thread = await openai_client.beta.threads.create(**data) # type: ignore
+
+ return Thread(**message_thread.dict())
+
+ # fmt: off
+
+ @overload
+ def create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ client: Optional[AsyncOpenAI],
+ acreate_thread: Literal[True],
+ ) -> Coroutine[None, None, Thread]:
+ ...
+
+ @overload
+ def create_thread(
+ self,
+ metadata: Optional[dict],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+ client: Optional[OpenAI],
+ acreate_thread: Optional[Literal[False]],
+ ) -> Thread:
+ ...
+
+ # fmt: on
+
def create_thread(
self,
metadata: Optional[dict],
@@ -1479,9 +2358,10 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
- ) -> Thread:
+ client=None,
+ acreate_thread=None,
+ ):
"""
Here's an example:
```
@@ -1492,6 +2372,17 @@ class OpenAIAssistantsAPI(BaseLLM):
openai_api.create_thread(messages=[message])
```
"""
+ if acreate_thread is not None and acreate_thread == True:
+ return self.async_create_thread(
+ metadata=metadata,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ messages=messages,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1511,6 +2402,61 @@ class OpenAIAssistantsAPI(BaseLLM):
return Thread(**message_thread.dict())
+ async def async_get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ ) -> Thread:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+ return Thread(**response.dict())
+
+ # fmt: off
+
+ @overload
+ def get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ aget_thread: Literal[True],
+ ) -> Coroutine[None, None, Thread]:
+ ...
+
+ @overload
+ def get_thread(
+ self,
+ thread_id: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI],
+ aget_thread: Optional[Literal[False]],
+ ) -> Thread:
+ ...
+
+ # fmt: on
+
def get_thread(
self,
thread_id: str,
@@ -1519,8 +2465,19 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI],
- ) -> Thread:
+ client=None,
+ aget_thread=None,
+ ):
+ if aget_thread is not None and aget_thread == True:
+ return self.async_get_thread(
+ thread_id=thread_id,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1539,6 +2496,142 @@ class OpenAIAssistantsAPI(BaseLLM):
### RUNS ###
+ async def arun_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[AsyncOpenAI],
+ ) -> Run:
+ openai_client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+
+ response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ )
+
+ return response
+
+ def async_run_thread_stream(
+ self,
+ client: AsyncOpenAI,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ tools: Optional[Iterable[AssistantToolParam]],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
+ data = {
+ "thread_id": thread_id,
+ "assistant_id": assistant_id,
+ "additional_instructions": additional_instructions,
+ "instructions": instructions,
+ "metadata": metadata,
+ "model": model,
+ "tools": tools,
+ }
+ if event_handler is not None:
+ data["event_handler"] = event_handler
+ return client.beta.threads.runs.stream(**data) # type: ignore
+
+ def run_thread_stream(
+ self,
+ client: OpenAI,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ tools: Optional[Iterable[AssistantToolParam]],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> AssistantStreamManager[AssistantEventHandler]:
+ data = {
+ "thread_id": thread_id,
+ "assistant_id": assistant_id,
+ "additional_instructions": additional_instructions,
+ "instructions": instructions,
+ "metadata": metadata,
+ "model": model,
+ "tools": tools,
+ }
+ if event_handler is not None:
+ data["event_handler"] = event_handler
+ return client.beta.threads.runs.stream(**data) # type: ignore
+
+ # fmt: off
+
+ @overload
+ def run_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client,
+ arun_thread: Literal[True],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> Coroutine[None, None, Run]:
+ ...
+
+ @overload
+ def run_thread(
+ self,
+ thread_id: str,
+ assistant_id: str,
+ additional_instructions: Optional[str],
+ instructions: Optional[str],
+ metadata: Optional[object],
+ model: Optional[str],
+ stream: Optional[bool],
+ tools: Optional[Iterable[AssistantToolParam]],
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client,
+ arun_thread: Optional[Literal[False]],
+ event_handler: Optional[AssistantEventHandler],
+ ) -> Run:
+ ...
+
+ # fmt: on
+
def run_thread(
self,
thread_id: str,
@@ -1554,8 +2647,47 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
- client: Optional[OpenAI],
- ) -> Run:
+ client=None,
+ arun_thread=None,
+ event_handler: Optional[AssistantEventHandler] = None,
+ ):
+ if arun_thread is not None and arun_thread == True:
+ if stream is not None and stream == True:
+ _client = self.async_get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
+ return self.async_run_thread_stream(
+ client=_client,
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ event_handler=event_handler,
+ )
+ return self.arun_thread(
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ stream=stream,
+ tools=tools,
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ )
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1565,6 +2697,19 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client,
)
+ if stream is not None and stream == True:
+ return self.run_thread_stream(
+ client=openai_client,
+ thread_id=thread_id,
+ assistant_id=assistant_id,
+ additional_instructions=additional_instructions,
+ instructions=instructions,
+ metadata=metadata,
+ model=model,
+ tools=tools,
+ event_handler=event_handler,
+ )
+
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
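The audio_speech / async_audio_speech pair added earlier in this file's diff follows the same convention: passing `aspeech=True` makes the sync method return the coroutine from its async twin; otherwise it calls the OpenAI speech endpoint directly and returns an HttpxBinaryResponseContent. A hedged sketch of driving the handler directly; the model, voice, and key are placeholders, and `stream_to_file` is the OpenAI SDK's helper on the binary response object, not something defined in this patch:

```python
from litellm.llms.openai import OpenAIChatCompletion

handler = OpenAIChatCompletion()

# Synchronous text-to-speech call (aspeech left unset / False).
speech = handler.audio_speech(
    model="tts-1",                         # illustrative model name
    input="Hello from the speech refactor!",
    voice="alloy",                         # illustrative voice
    optional_params={},
    api_key="sk-...",                      # placeholder
    api_base="https://api.openai.com/v1",
    organization=None,
    project=None,
    max_retries=2,
    timeout=600.0,
)
speech.stream_to_file("speech.mp3")        # write the returned audio bytes to disk
```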
diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py
index 1e7e1d334..a3245cdac 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase.py
@@ -1,7 +1,7 @@
# What is this?
## Controller file for Predibase Integration - https://predibase.com/
-
+from functools import partial
import os, types
import json
from enum import Enum
@@ -51,6 +51,32 @@ class PredibaseError(Exception):
) # Call the base class constructor with the parameters it needs
+async def make_call(
+ client: AsyncHTTPHandler,
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200:
+ raise PredibaseError(status_code=response.status_code, message=response.text)
+
+ completion_stream = response.aiter_lines()
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=completion_stream, # Pass the completion stream for logging
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
class PredibaseConfig:
"""
Reference: https://docs.predibase.com/user-guide/inference/rest_api
@@ -126,11 +152,17 @@ class PredibaseChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
- def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
+ def _validate_environment(
+ self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
+ ) -> dict:
if api_key is None:
raise ValueError(
"Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
)
+ if tenant_id is None:
+ raise ValueError(
+ "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=)`) or in env - `PREDIBASE_TENANT_ID`."
+ )
headers = {
"content-type": "application/json",
"Authorization": "Bearer {}".format(api_key),
@@ -304,7 +336,7 @@ class PredibaseChatCompletion(BaseLLM):
logger_fn=None,
headers: dict = {},
) -> Union[ModelResponse, CustomStreamWrapper]:
- headers = self._validate_environment(api_key, headers)
+ headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
completion_url = ""
input_text = ""
base_url = "https://serving.app.predibase.com"
@@ -455,9 +487,16 @@ class PredibaseChatCompletion(BaseLLM):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
- response = await self.async_handler.post(
- api_base, headers=headers, data=json.dumps(data)
- )
+ try:
+ response = await self.async_handler.post(
+ api_base, headers=headers, data=json.dumps(data)
+ )
+ except httpx.HTTPStatusError as e:
+ raise PredibaseError(
+ status_code=e.response.status_code, message=e.response.text
+ )
+ except Exception as e:
+ raise PredibaseError(status_code=500, message=str(e))
return self.process_response(
model=model,
response=response,
@@ -488,26 +527,19 @@ class PredibaseChatCompletion(BaseLLM):
logger_fn=None,
headers={},
) -> CustomStreamWrapper:
- self.async_handler = AsyncHTTPHandler(
- timeout=httpx.Timeout(timeout=600.0, connect=5.0)
- )
data["stream"] = True
- response = await self.async_handler.post(
- url=api_base,
- headers=headers,
- data=json.dumps(data),
- stream=True,
- )
-
- if response.status_code != 200:
- raise PredibaseError(
- status_code=response.status_code, message=response.text
- )
-
- completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
- completion_stream=completion_stream,
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ api_base=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
model=model,
custom_llm_provider="predibase",
logging_obj=logging_obj,
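The streaming rewrite above makes the Predibase call lazy: CustomStreamWrapper now receives `completion_stream=None` plus a `make_call` partial, so the HTTP connection is only opened when the wrapper is first iterated. A minimal sketch of what that deferred call looks like when invoked on its own; the endpoint URL, key, and payload are placeholders, and the no-op logger stands in for litellm's Logging object:

```python
import asyncio
import json
from functools import partial

from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.predibase import make_call

class _NoOpLogger:
    def post_call(self, **kwargs):  # matches the keyword call made inside make_call
        pass

deferred_call = partial(
    make_call,
    api_base="https://serving.app.predibase.com/<tenant_id>/deployments/v2/llms/my-model/generate_stream",  # placeholder
    headers={"Authorization": "Bearer <PREDIBASE_API_KEY>", "content-type": "application/json"},
    data=json.dumps({"inputs": "Hello", "parameters": {"max_new_tokens": 32}, "stream": True}),
    model="my-model",
    messages=[{"role": "user", "content": "Hello"}],
    logging_obj=_NoOpLogger(),
)

async def consume():
    client = AsyncHTTPHandler()
    stream = await deferred_call(client=client)  # connection opens here, not earlier
    async for line in stream:
        print(line)

# asyncio.run(consume())
```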
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index cf593369c..41ecb486c 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -12,6 +12,7 @@ from typing import (
Sequence,
)
import litellm
+import litellm.types
from litellm.types.completion import (
ChatCompletionUserMessageParam,
ChatCompletionSystemMessageParam,
@@ -20,9 +21,12 @@ from litellm.types.completion import (
ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam,
)
+import litellm.types.llms
from litellm.types.llms.anthropic import *
import uuid
+import litellm.types.llms.vertex_ai
+
def default_pt(messages):
return " ".join(message["content"] for message in messages)
@@ -111,6 +115,26 @@ def llama_2_chat_pt(messages):
return prompt
+def convert_to_ollama_image(openai_image_url: str):
+ try:
+ if openai_image_url.startswith("http"):
+ openai_image_url = convert_url_to_base64(url=openai_image_url)
+
+ if openai_image_url.startswith("data:image/"):
+ # Extract the base64 image data
+ base64_data = openai_image_url.split("data:image/")[1].split(";base64,")[1]
+ else:
+ base64_data = openai_image_url
+
+ return base64_data
+ except Exception as e:
+ if "Error: Unable to fetch image from URL" in str(e):
+ raise e
+ raise Exception(
+ """Image url not in expected format. Example Expected input - "image_url": "data:image/jpeg;base64,{base64_image}". """
+ )
+
+
def ollama_pt(
model, messages
): # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
@@ -143,8 +167,10 @@ def ollama_pt(
if element["type"] == "text":
prompt += element["text"]
elif element["type"] == "image_url":
- image_url = element["image_url"]["url"]
- images.append(image_url)
+ base64_image = convert_to_ollama_image(
+ element["image_url"]["url"]
+ )
+ images.append(base64_image)
return {"prompt": prompt, "images": images}
else:
prompt = "".join(
@@ -841,6 +867,175 @@ def anthropic_messages_pt_xml(messages: list):
# ------------------------------------------------------------------------------
+def infer_protocol_value(
+ value: Any,
+) -> Literal[
+ "string_value",
+ "number_value",
+ "bool_value",
+ "struct_value",
+ "list_value",
+ "null_value",
+ "unknown",
+]:
+ if value is None:
+ return "null_value"
+ # check bool before int/float - bool is a subclass of int in Python
+ if isinstance(value, bool):
+ return "bool_value"
+ if isinstance(value, int) or isinstance(value, float):
+ return "number_value"
+ if isinstance(value, str):
+ return "string_value"
+ if isinstance(value, dict):
+ return "struct_value"
+ if isinstance(value, list):
+ return "list_value"
+
+ return "unknown"
+
+
+def convert_to_gemini_tool_call_invoke(
+ tool_calls: list,
+) -> List[litellm.types.llms.vertex_ai.PartType]:
+ """
+ OpenAI tool invokes:
+ {
+ "role": "assistant",
+ "content": null,
+ "tool_calls": [
+ {
+ "id": "call_abc123",
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "arguments": "{\n\"location\": \"Boston, MA\"\n}"
+ }
+ }
+ ]
+ },
+ """
+ """
+ Gemini tool call invokes: - https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#submit-api-output
+ content {
+ role: "model"
+ parts [
+ {
+ function_call {
+ name: "get_current_weather"
+ args {
+ fields {
+ key: "unit"
+ value {
+ string_value: "fahrenheit"
+ }
+ }
+ fields {
+ key: "predicted_temperature"
+ value {
+ number_value: 45
+ }
+ }
+ fields {
+ key: "location"
+ value {
+ string_value: "Boston, MA"
+ }
+ }
+ }
+ },
+ {
+ function_call {
+ name: "get_current_weather"
+ args {
+ fields {
+ key: "location"
+ value {
+ string_value: "San Francisco"
+ }
+ }
+ }
+ }
+ }
+ ]
+ }
+ """
+
+ """
+ - json.load the arguments
+ - iterate through arguments -> create a FunctionCallArgs for each field
+ """
+ try:
+ _parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
+ for tool in tool_calls:
+ if "function" in tool:
+ name = tool["function"].get("name", "")
+ arguments = tool["function"].get("arguments", "")
+ arguments_dict = json.loads(arguments)
+ for k, v in arguments_dict.items():
+ inferred_protocol_value = infer_protocol_value(value=v)
+ _field = litellm.types.llms.vertex_ai.Field(
+ key=k, value={inferred_protocol_value: v}
+ )
+ _fields = litellm.types.llms.vertex_ai.FunctionCallArgs(
+ fields=_field
+ )
+ function_call = litellm.types.llms.vertex_ai.FunctionCall(
+ name=name,
+ args=_fields,
+ )
+ _parts_list.append(
+ litellm.types.llms.vertex_ai.PartType(function_call=function_call)
+ )
+ return _parts_list
+ except Exception as e:
+ raise Exception(
+ "Unable to convert openai tool calls={} to gemini tool calls. Received error={}".format(
+ tool_calls, str(e)
+ )
+ )
+
+
+def convert_to_gemini_tool_call_result(
+ message: dict,
+) -> litellm.types.llms.vertex_ai.PartType:
+ """
+ OpenAI message with a tool result looks like:
+ {
+ "tool_call_id": "tool_1",
+ "role": "tool",
+ "name": "get_current_weather",
+ "content": "function result goes here",
+ },
+
+ OpenAI message with a function call result looks like:
+ {
+ "role": "function",
+ "name": "get_current_weather",
+ "content": "function result goes here",
+ }
+ """
+ content = message.get("content", "")
+ name = message.get("name", "")
+
+ # We can't determine from openai message format whether it's a successful or
+ # error call result so default to the successful result template
+ inferred_content_value = infer_protocol_value(value=content)
+
+ _field = litellm.types.llms.vertex_ai.Field(
+ key="content", value={inferred_content_value: content}
+ )
+
+ _function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
+
+ _function_response = litellm.types.llms.vertex_ai.FunctionResponse(
+ name=name, response=_function_call_args
+ )
+
+ _part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
+
+ return _part
+
+
def convert_to_anthropic_tool_result(message: dict) -> dict:
"""
OpenAI message with a tool result looks like:
@@ -1328,6 +1523,7 @@ def _gemini_vision_convert_messages(messages: list):
# Case 1: Image from URL
image = _load_image_from_url(img)
processed_images.append(image)
+
else:
try:
from PIL import Image
@@ -1335,8 +1531,23 @@ def _gemini_vision_convert_messages(messages: list):
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
- # Case 2: Image filepath (e.g. temp.jpeg) given
- image = Image.open(img)
+
+ if "base64" in img:
+ # Case 2: Base64 image data
+ import base64
+ import io
+
+ # Extract the base64 image data
+ base64_data = img.split("base64,")[1]
+
+ # Decode the base64 image data
+ image_data = base64.b64decode(base64_data)
+
+ # Load the image from the decoded data
+ image = Image.open(io.BytesIO(image_data))
+ else:
+ # Case 3: Image filepath (e.g. temp.jpeg) given
+ image = Image.open(img)
processed_images.append(image)
content = [prompt] + processed_images
return content
@@ -1513,7 +1724,7 @@ def prompt_factory(
elif custom_llm_provider == "clarifai":
if "claude" in model:
return anthropic_pt(messages=messages)
-
+
elif custom_llm_provider == "perplexity":
for message in messages:
message.pop("name", None)
diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index 386d24f59..ce62e51e9 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -251,7 +251,7 @@ async def async_handle_prediction_response(
logs = ""
while True and (status not in ["succeeded", "failed", "canceled"]):
print_verbose(f"replicate: polling endpoint: {prediction_url}")
- await asyncio.sleep(0.5)
+ await asyncio.sleep(0.5) # prevent replicate rate limit errors
response = await http_handler.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 84fec734f..5171b1efc 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -3,10 +3,15 @@ import json
from enum import Enum
import requests # type: ignore
import time
-from typing import Callable, Optional, Union, List
+from typing import Callable, Optional, Union, List, Literal, Any
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
+from litellm.types.llms.vertex_ai import *
+from litellm.llms.prompt_templates.factory import (
+ convert_to_gemini_tool_call_result,
+ convert_to_gemini_tool_call_invoke,
+)
class VertexAIError(Exception):
@@ -283,6 +288,139 @@ def _load_image_from_url(image_url: str):
return Image.from_bytes(data=image_bytes)
+def _convert_gemini_role(role: str) -> Literal["user", "model"]:
+ if role == "user":
+ return "user"
+ else:
+ return "model"
+
+
+def _process_gemini_image(image_url: str) -> PartType:
+ try:
+ if "gs://" in image_url:
+ # Case 1: Images with Cloud Storage URIs
+ # The supported MIME types for images include image/png and image/jpeg.
+ part_mime = "image/png" if "png" in image_url else "image/jpeg"
+ _file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
+ return PartType(file_data=_file_data)
+ elif "https:/" in image_url:
+ # Case 2: Images with direct links
+ image = _load_image_from_url(image_url)
+ _blob = BlobType(data=image.data, mime_type=image._mime_type)
+ return PartType(inline_data=_blob)
+ elif ".mp4" in image_url and "gs://" in image_url:
+ # Case 3: Videos with Cloud Storage URIs
+ part_mime = "video/mp4"
+ _file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
+ return PartType(file_data=_file_data)
+ elif "base64" in image_url:
+ # Case 4: Images with base64 encoding
+ import base64, re
+
+ # base 64 is passed as data:image/jpeg;base64,
+ image_metadata, img_without_base_64 = image_url.split(",")
+
+ # read mime_type from image_metadata, e.g. "data:image/jpeg;base64"
+ # Extract MIME type using regular expression
+ mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
+
+ if mime_type_match:
+ mime_type = mime_type_match.group(1)
+ else:
+ mime_type = "image/jpeg"
+ decoded_img = base64.b64decode(img_without_base_64)
+ _blob = BlobType(data=decoded_img, mime_type=mime_type)
+ return PartType(inline_data=_blob)
+ raise Exception("Invalid image received - {}".format(image_url))
+ except Exception as e:
+ raise e
+
+
+def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
+ """
+ Converts given messages from OpenAI format to Gemini format
+
+ - Parts must be iterable
+ - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
+ - Please ensure that function response turn comes immediately after a function call turn
+ """
+ user_message_types = {"user", "system"}
+ contents: List[ContentType] = []
+
+ msg_i = 0
+ while msg_i < len(messages):
+ user_content: List[PartType] = []
+ init_msg_i = msg_i
+ ## MERGE CONSECUTIVE USER CONTENT ##
+ while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
+ if isinstance(messages[msg_i]["content"], list):
+ _parts: List[PartType] = []
+ for element in messages[msg_i]["content"]:
+ if isinstance(element, dict):
+ if element["type"] == "text":
+ _part = PartType(text=element["text"])
+ _parts.append(_part)
+ elif element["type"] == "image_url":
+ image_url = element["image_url"]["url"]
+ _part = _process_gemini_image(image_url=image_url)
+ _parts.append(_part) # type: ignore
+ user_content.extend(_parts)
+ else:
+ _part = PartType(text=messages[msg_i]["content"])
+ user_content.append(_part)
+
+ msg_i += 1
+
+ if user_content:
+ contents.append(ContentType(role="user", parts=user_content))
+ assistant_content = []
+ ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+ while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+ if isinstance(messages[msg_i]["content"], list):
+ _parts = []
+ for element in messages[msg_i]["content"]:
+ if isinstance(element, dict):
+ if element["type"] == "text":
+ _part = PartType(text=element["text"])
+ _parts.append(_part)
+ elif element["type"] == "image_url":
+ image_url = element["image_url"]["url"]
+ _part = _process_gemini_image(image_url=image_url)
+ _parts.append(_part) # type: ignore
+ assistant_content.extend(_parts)
+ elif messages[msg_i].get(
+ "tool_calls", []
+ ): # support assistant tool invoke conversion
+ assistant_content.extend(
+ convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
+ )
+ else:
+ assistant_text = (
+ messages[msg_i].get("content") or ""
+ ) # either string or none
+ if assistant_text:
+ assistant_content.append(PartType(text=assistant_text))
+
+ msg_i += 1
+
+ if assistant_content:
+ contents.append(ContentType(role="model", parts=assistant_content))
+
+ ## APPEND TOOL CALL MESSAGES ##
+ if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+ _part = convert_to_gemini_tool_call_result(messages[msg_i])
+ contents.append(ContentType(parts=[_part])) # type: ignore
+ msg_i += 1
+ if msg_i == init_msg_i: # prevent infinite loops
+ raise Exception(
+ "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+ messages[msg_i]
+ )
+ )
+
+ return contents
+
+
def _gemini_vision_convert_messages(messages: list):
"""
Converts given messages for GPT-4 Vision to Gemini format.
@@ -389,6 +527,19 @@ def _gemini_vision_convert_messages(messages: list):
raise e
+def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
+ _cache_key = f"{model}-{vertex_project}-{vertex_location}"
+ return _cache_key
+
+
+def _get_client_from_cache(client_cache_key: str):
+ return litellm.in_memory_llm_clients_cache.get(client_cache_key, None)
+
+
+def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
+ litellm.in_memory_llm_clients_cache[client_cache_key] = vertex_llm_model
+
+
def completion(
model: str,
messages: list,
@@ -396,10 +547,10 @@ def completion(
print_verbose: Callable,
encoding,
logging_obj,
+ optional_params: dict,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
- optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
@@ -442,23 +593,32 @@ def completion(
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
- if vertex_credentials is not None and isinstance(vertex_credentials, str):
- import google.oauth2.service_account
- json_obj = json.loads(vertex_credentials)
+ _cache_key = _get_client_cache_key(
+ model=model, vertex_project=vertex_project, vertex_location=vertex_location
+ )
+ _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)
- creds = google.oauth2.service_account.Credentials.from_service_account_info(
- json_obj,
- scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ if _vertex_llm_model_object is None:
+ if vertex_credentials is not None and isinstance(vertex_credentials, str):
+ import google.oauth2.service_account
+
+ json_obj = json.loads(vertex_credentials)
+
+ creds = (
+ google.oauth2.service_account.Credentials.from_service_account_info(
+ json_obj,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ )
+ else:
+ creds, _ = google.auth.default(quota_project_id=vertex_project)
+ print_verbose(
+ f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
+ )
+ vertexai.init(
+ project=vertex_project, location=vertex_location, credentials=creds
)
- else:
- creds, _ = google.auth.default(quota_project_id=vertex_project)
- print_verbose(
- f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
- )
- vertexai.init(
- project=vertex_project, location=vertex_location, credentials=creds
- )
## Load Config
config = litellm.VertexAIConfig.get_config()
@@ -501,23 +661,27 @@ def completion(
model in litellm.vertex_language_models
or model in litellm.vertex_vision_models
):
- llm_model = GenerativeModel(model)
+ llm_model = _vertex_llm_model_object or GenerativeModel(model)
mode = "vision"
request_str += f"llm_model = GenerativeModel({model})\n"
elif model in litellm.vertex_chat_models:
- llm_model = ChatModel.from_pretrained(model)
+ llm_model = _vertex_llm_model_object or ChatModel.from_pretrained(model)
mode = "chat"
request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
elif model in litellm.vertex_text_models:
- llm_model = TextGenerationModel.from_pretrained(model)
+ llm_model = _vertex_llm_model_object or TextGenerationModel.from_pretrained(
+ model
+ )
mode = "text"
request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
elif model in litellm.vertex_code_text_models:
- llm_model = CodeGenerationModel.from_pretrained(model)
+ llm_model = _vertex_llm_model_object or CodeGenerationModel.from_pretrained(
+ model
+ )
mode = "text"
request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
elif model in litellm.vertex_code_chat_models: # vertex_code_llm_models
- llm_model = CodeChatModel.from_pretrained(model)
+ llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
mode = "chat"
request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
elif model == "private":
@@ -556,6 +720,7 @@ def completion(
"model_response": model_response,
"encoding": encoding,
"messages": messages,
+ "request_str": request_str,
"print_verbose": print_verbose,
"client_options": client_options,
"instances": instances,
@@ -574,11 +739,9 @@ def completion(
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
- prompt, images = _gemini_vision_convert_messages(messages=messages)
- content = [prompt] + images
+ content = _gemini_convert_messages_with_history(messages=messages)
stream = optional_params.pop("stream", False)
if stream == True:
-
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(
input=prompt,
@@ -589,7 +752,7 @@ def completion(
},
)
- model_response = llm_model.generate_content(
+ _model_response = llm_model.generate_content(
contents=content,
generation_config=optional_params,
safety_settings=safety_settings,
@@ -597,7 +760,7 @@ def completion(
tools=tools,
)
- return model_response
+ return _model_response
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
@@ -850,12 +1013,12 @@ async def async_completion(
mode: str,
prompt: str,
model: str,
+ messages: list,
model_response: ModelResponse,
- logging_obj=None,
- request_str=None,
+ request_str: str,
+ print_verbose: Callable,
+ logging_obj,
encoding=None,
- messages=None,
- print_verbose=None,
client_options=None,
instances=None,
vertex_project=None,
@@ -875,8 +1038,7 @@ async def async_completion(
tools = optional_params.pop("tools", None)
stream = optional_params.pop("stream", False)
- prompt, images = _gemini_vision_convert_messages(messages=messages)
- content = [prompt] + images
+ content = _gemini_convert_messages_with_history(messages=messages)
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
@@ -898,6 +1060,15 @@ async def async_completion(
tools=tools,
)
+ _cache_key = _get_client_cache_key(
+ model=model,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ )
+ _set_client_in_cache(
+ client_cache_key=_cache_key, vertex_llm_model=llm_model
+ )
+
if tools is not None and bool(
getattr(response.candidates[0].content.parts[0], "function_call", None)
):
@@ -1076,11 +1247,11 @@ async def async_streaming(
prompt: str,
model: str,
model_response: ModelResponse,
- logging_obj=None,
- request_str=None,
+ messages: list,
+ print_verbose: Callable,
+ logging_obj,
+ request_str: str,
encoding=None,
- messages=None,
- print_verbose=None,
client_options=None,
instances=None,
vertex_project=None,
@@ -1097,8 +1268,8 @@ async def async_streaming(
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
- prompt, images = _gemini_vision_convert_messages(messages=messages)
- content = [prompt] + images
+ content = _gemini_convert_messages_with_history(messages=messages)
+
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(
input=prompt,
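The new _gemini_convert_messages_with_history helper replaces the old prompt-plus-images tuple: it merges consecutive user/system turns, merges consecutive assistant turns, and appends tool-result parts so Gemini sees strictly alternating user/model content, while the client-cache helpers avoid re-initialising vertexai on every call. A hedged example of the message conversion (the gs:// URI is a placeholder):

```python
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "gs://my-bucket/cat.png"}},  # placeholder URI
        ],
    },
]

contents = _gemini_convert_messages_with_history(messages=messages)
# Roughly:
# [{"role": "user",  "parts": [{"text": "You are terse."}, {"text": "What is the capital of France?"}]},
#  {"role": "model", "parts": [{"text": "Paris."}]},
#  {"role": "user",  "parts": [{"text": "Describe this image."},
#                              {"file_data": {"mime_type": "image/png", "file_uri": "gs://my-bucket/cat.png"}}]}]
```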
diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py
index 3bdcf4fd6..065294280 100644
--- a/litellm/llms/vertex_ai_anthropic.py
+++ b/litellm/llms/vertex_ai_anthropic.py
@@ -35,7 +35,7 @@ class VertexAIError(Exception):
class VertexAIAnthropicConfig:
"""
- Reference: https://docs.anthropic.com/claude/reference/messages_post
+ Reference:https://docs.anthropic.com/claude/reference/messages_post
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
new file mode 100644
index 000000000..b8c698c90
--- /dev/null
+++ b/litellm/llms/vertex_httpx.py
@@ -0,0 +1,224 @@
+import os, types
+import json
+from enum import Enum
+import requests # type: ignore
+import time
+from typing import Callable, Optional, Union, List, Any, Tuple
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
+import litellm, uuid
+import httpx, inspect # type: ignore
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from .base import BaseLLM
+
+
+class VertexAIError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(
+ method="POST", url=" https://cloud.google.com/vertex-ai/"
+ )
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
+
+
+class VertexLLM(BaseLLM):
+ def __init__(self) -> None:
+ super().__init__()
+ self.access_token: Optional[str] = None
+ self.refresh_token: Optional[str] = None
+ self._credentials: Optional[Any] = None
+ self.project_id: Optional[str] = None
+ self.async_handler: Optional[AsyncHTTPHandler] = None
+
+ def load_auth(self) -> Tuple[Any, str]:
+ from google.auth.transport.requests import Request # type: ignore[import-untyped]
+ from google.auth.credentials import Credentials # type: ignore[import-untyped]
+ import google.auth as google_auth
+
+ credentials, project_id = google_auth.default(
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+
+ credentials.refresh(Request())
+
+ if not project_id:
+ raise ValueError("Could not resolve project_id")
+
+ if not isinstance(project_id, str):
+ raise TypeError(
+ f"Expected project_id to be a str but got {type(project_id)}"
+ )
+
+ return credentials, project_id
+
+ def refresh_auth(self, credentials: Any) -> None:
+ from google.auth.transport.requests import Request # type: ignore[import-untyped]
+
+ credentials.refresh(Request())
+
+ def _prepare_request(self, request: httpx.Request) -> None:
+ access_token = self._ensure_access_token()
+
+ if request.headers.get("Authorization"):
+ # already authenticated, nothing for us to do
+ return
+
+ request.headers["Authorization"] = f"Bearer {access_token}"
+
+ def _ensure_access_token(self) -> str:
+ if self.access_token is not None:
+ return self.access_token
+
+ if not self._credentials:
+ self._credentials, project_id = self.load_auth()
+ if not self.project_id:
+ self.project_id = project_id
+ else:
+ self.refresh_auth(self._credentials)
+
+ if not self._credentials.token:
+ raise RuntimeError("Could not resolve API token from the environment")
+
+ assert isinstance(self._credentials.token, str)
+ return self._credentials.token
+
+ def image_generation(
+ self,
+ prompt: str,
+ vertex_project: str,
+ vertex_location: str,
+ model: Optional[
+ str
+ ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
+ client: Optional[AsyncHTTPHandler] = None,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[int] = None,
+ logging_obj=None,
+ model_response=None,
+ aimg_generation=False,
+ ):
+ if aimg_generation == True:
+ response = self.aimage_generation(
+ prompt=prompt,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ model=model,
+ client=client,
+ optional_params=optional_params,
+ timeout=timeout,
+ logging_obj=logging_obj,
+ model_response=model_response,
+ )
+ return response
+
+ async def aimage_generation(
+ self,
+ prompt: str,
+ vertex_project: str,
+ vertex_location: str,
+ model_response: litellm.ImageResponse,
+ model: Optional[
+ str
+ ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
+ client: Optional[AsyncHTTPHandler] = None,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[int] = None,
+ logging_obj=None,
+ ):
+ response = None
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
+ else:
+ self.async_handler = client # type: ignore
+
+ # make POST request to
+ # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+ """
+ Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+ curl -X POST \
+ -H "Authorization: Bearer $(gcloud auth print-access-token)" \
+ -H "Content-Type: application/json; charset=utf-8" \
+ -d {
+ "instances": [
+ {
+ "prompt": "a cat"
+ }
+ ],
+ "parameters": {
+ "sampleCount": 1
+ }
+ } \
+ "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
+ """
+ auth_header = self._ensure_access_token()
+ optional_params = optional_params or {
+ "sampleCount": 1
+ } # default optional params
+
+ request_data = {
+ "instances": [{"prompt": prompt}],
+ "parameters": optional_params,
+ }
+
+ request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ response = await self.async_handler.post(
+ url=url,
+ headers={
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {auth_header}",
+ },
+ data=json.dumps(request_data),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+ """
+ Vertex AI Image generation response example:
+ {
+ "predictions": [
+ {
+ "bytesBase64Encoded": "BASE64_IMG_BYTES",
+ "mimeType": "image/png"
+ },
+ {
+ "mimeType": "image/png",
+ "bytesBase64Encoded": "BASE64_IMG_BYTES"
+ }
+ ]
+ }
+ """
+
+ _json_response = response.json()
+ _predictions = _json_response["predictions"]
+
+ _response_data: List[litellm.ImageObject] = []
+ for _prediction in _predictions:
+ _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
+ image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
+ _response_data.append(image_object)
+
+ model_response.data = _response_data
+
+ return model_response
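A rough usage sketch for the VertexLLM image-generation handler above, going through litellm's public API. It assumes Application Default Credentials are available, that project/location fall back to litellm.vertex_project / litellm.vertex_location as in the image_generation dispatch added to main.py further down, and that `vertex_ai/imagegeneration@006` is an illustrative model id (it matches the pricing entry added later in this diff). Only the async flow is implemented in the handler, so the sketch uses aimage_generation:

import asyncio
import litellm

async def main():
    litellm.vertex_project = "my-gcp-project"   # placeholder project id
    litellm.vertex_location = "us-central1"     # region used to build the :predict URL
    response = await litellm.aimage_generation(
        prompt="a cat sitting on a windowsill",
        model="vertex_ai/imagegeneration@006",
    )
    # the handler maps each prediction's bytesBase64Encoded field onto ImageObject.b64_json
    print(len(response.data), response.data[0].b64_json[:32])

asyncio.run(main())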
diff --git a/litellm/main.py b/litellm/main.py
index 2e4132a42..f76d6c521 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -73,12 +73,14 @@ from .llms import (
)
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.azure import AzureChatCompletion
+from .llms.databricks import DatabricksChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
+from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
@@ -90,6 +92,7 @@ import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict, Union, Mapping
from .caching import enable_cache, disable_cache, update_cache
+from .types.llms.openai import HttpxBinaryResponseContent
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
@@ -110,6 +113,7 @@ from litellm.utils import (
####### ENVIRONMENT VARIABLES ###################
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
+databricks_chat_completions = DatabricksChatCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
@@ -118,6 +122,7 @@ huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
+vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################
@@ -219,7 +224,7 @@ async def acompletion(
extra_headers: Optional[dict] = None,
# Optional liteLLM function params
**kwargs,
-):
+) -> Union[ModelResponse, CustomStreamWrapper]:
"""
Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -290,6 +295,7 @@ async def acompletion(
"api_version": api_version,
"api_key": api_key,
"model_list": model_list,
+ "extra_headers": extra_headers,
"acompletion": True, # assuming this is a required parameter
}
if custom_llm_provider is None:
@@ -326,13 +332,16 @@ async def acompletion(
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
or custom_llm_provider == "predibase"
- or (custom_llm_provider == "bedrock" and "cohere" in model)
+ or custom_llm_provider == "bedrock"
+ or custom_llm_provider == "databricks"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
init_response, ModelResponse
): ## CACHING SCENARIO
+ if isinstance(init_response, dict):
+ init_response = ModelResponse(**init_response)
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
@@ -355,6 +364,7 @@ async def acompletion(
) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
return response
except Exception as e:
+ traceback.print_exc()
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
model=model,
@@ -368,6 +378,8 @@ async def acompletion(
async def _async_streaming(response, model, custom_llm_provider, args):
try:
print_verbose(f"received response in _async_streaming: {response}")
+ if asyncio.iscoroutine(response):
+ response = await response
async for line in response:
print_verbose(f"line in async streaming: {line}")
yield line
@@ -413,6 +425,8 @@ def mock_completion(
api_key="mock-key",
)
if isinstance(mock_response, Exception):
+ if isinstance(mock_response, openai.APIError):
+ raise mock_response
raise litellm.APIError(
status_code=500, # type: ignore
message=str(mock_response),
@@ -420,6 +434,10 @@ def mock_completion(
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
+ time_delay = kwargs.get("mock_delay", None)
+ if time_delay is not None:
+ time.sleep(time_delay)
+
model_response = ModelResponse(stream=stream)
if stream is True:
# don't try to access stream object,
@@ -456,7 +474,9 @@ def mock_completion(
return model_response
- except:
+ except Exception as e:
+ if isinstance(e, openai.APIError):
+ raise e
traceback.print_exc()
raise Exception("Mock completion response failed")
@@ -482,7 +502,7 @@ def completion(
response_format: Optional[dict] = None,
seed: Optional[int] = None,
tools: Optional[List] = None,
- tool_choice: Optional[str] = None,
+ tool_choice: Optional[Union[str, dict]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
deployment_id=None,
@@ -668,6 +688,7 @@ def completion(
"region_name",
"allowed_model_region",
"model_config",
+ "fastest_response",
]
default_params = openai_params + litellm_params
@@ -817,6 +838,7 @@ def completion(
logprobs=logprobs,
top_logprobs=top_logprobs,
extra_headers=extra_headers,
+ api_version=api_version,
**non_default_params,
)
@@ -857,6 +879,7 @@ def completion(
user=user,
optional_params=optional_params,
litellm_params=litellm_params,
+ custom_llm_provider=custom_llm_provider,
)
if mock_response:
return mock_completion(
@@ -866,6 +889,7 @@ def completion(
mock_response=mock_response,
logging=logging,
acompletion=acompletion,
+ mock_delay=kwargs.get("mock_delay", None),
)
if custom_llm_provider == "azure":
# azure configs
@@ -1611,6 +1635,61 @@ def completion(
)
return response
response = model_response
+ elif custom_llm_provider == "databricks":
+ api_base = (
+ api_base # for databricks we check in get_llm_provider and pass in the api base from there
+ or litellm.api_base
+ or os.getenv("DATABRICKS_API_BASE")
+ )
+
+ # set API KEY
+ api_key = (
+ api_key
+ or litellm.api_key # for databricks we check in get_llm_provider and pass in the api key from there
+ or litellm.databricks_key
+ or get_secret("DATABRICKS_API_KEY")
+ )
+
+ headers = headers or litellm.headers
+
+ ## COMPLETION CALL
+ try:
+ response = databricks_chat_completions.completion(
+ model=model,
+ messages=messages,
+ headers=headers,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ api_key=api_key,
+ api_base=api_base,
+ acompletion=acompletion,
+ logging_obj=logging,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ timeout=timeout, # type: ignore
+ custom_prompt_dict=custom_prompt_dict,
+ client=client, # pass AsyncOpenAI, OpenAI client
+ encoding=encoding,
+ )
+ except Exception as e:
+ ## LOGGING - log the original exception returned
+ logging.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=str(e),
+ additional_args={"headers": headers},
+ )
+ raise e
+
+ if optional_params.get("stream", False):
+ ## LOGGING
+ logging.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=response,
+ additional_args={"headers": headers},
+ )
elif custom_llm_provider == "openrouter":
api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1"
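A hedged sketch of the new Databricks completion branch above. It assumes a Databricks model-serving workspace and uses the same environment variables the branch reads (DATABRICKS_API_KEY / DATABRICKS_API_BASE); the token and workspace URL are placeholders, and the model name comes from the pricing entries added later in this diff:

import os
import litellm

os.environ["DATABRICKS_API_KEY"] = "dapi-..."  # placeholder workspace token
os.environ["DATABRICKS_API_BASE"] = "https://my-workspace.cloud.databricks.com/serving-endpoints"  # placeholder

resp = litellm.completion(
    model="databricks/databricks-dbrx-instruct",  # provider prefix routes into the databricks branch
    messages=[{"role": "user", "content": "Summarize DBRX in one sentence."}],
)
print(resp.choices[0].message.content)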
@@ -1979,23 +2058,9 @@ def completion(
# boto3 reads keys from .env
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
- if "cohere" in model:
- response = bedrock_chat_completion.completion(
- model=model,
- messages=messages,
- custom_prompt_dict=litellm.custom_prompt_dict,
- model_response=model_response,
- print_verbose=print_verbose,
- optional_params=optional_params,
- litellm_params=litellm_params,
- logger_fn=logger_fn,
- encoding=encoding,
- logging_obj=logging,
- extra_headers=extra_headers,
- timeout=timeout,
- acompletion=acompletion,
- )
- else:
+ if (
+ "aws_bedrock_client" in optional_params
+ ): # use old bedrock flow for aws_bedrock_client users.
response = bedrock.completion(
model=model,
messages=messages,
@@ -2031,7 +2096,23 @@ def completion(
custom_llm_provider="bedrock",
logging_obj=logging,
)
-
+ else:
+ response = bedrock_chat_completion.completion(
+ model=model,
+ messages=messages,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ encoding=encoding,
+ logging_obj=logging,
+ extra_headers=extra_headers,
+ timeout=timeout,
+ acompletion=acompletion,
+ client=client,
+ )
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
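With the rewritten branch above, Bedrock requests now go through the httpx-based BedrockLLM handler for every model (not just Cohere) unless a pre-built aws_bedrock_client is passed in, and the widened check in acompletion makes the async path genuinely non-blocking. A minimal sketch, assuming AWS credentials come from the usual environment/boto chain and using an illustrative Claude 3 model id:

import asyncio
import litellm

async def main():
    resp = await litellm.acompletion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "hello"}],
        aws_region_name="us-east-1",   # region kwarg follows litellm's usual bedrock conventions
    )
    print(resp.choices[0].message.content)

asyncio.run(main())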
@@ -2334,6 +2415,7 @@ def completion(
"top_k": kwargs.get("top_k", 40),
},
},
+ verify=litellm.ssl_verify,
)
response_json = resp.json()
"""
@@ -2472,6 +2554,7 @@ def batch_completion(
list: A list of completion results.
"""
args = locals()
+
batch_messages = messages
completions = []
model = model
@@ -2525,7 +2608,15 @@ def batch_completion(
completions.append(future)
# Retrieve the results from the futures
- results = [future.result() for future in completions]
+ # results = [future.result() for future in completions]
+ # return exceptions if any
+ results = []
+ for future in completions:
+ try:
+ results.append(future.result())
+ except Exception as exc:
+ results.append(exc)
+
return results
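Because batch_completion now appends the raised exception instead of letting future.result() propagate, callers should check each element of the returned list. A short sketch:

import litellm

results = litellm.batch_completion(
    model="gpt-3.5-turbo",
    messages=[
        [{"role": "user", "content": "first prompt"}],
        [{"role": "user", "content": "second prompt"}],
    ],
)
for r in results:
    if isinstance(r, Exception):
        # failed calls are returned in place rather than aborting the whole batch
        print("call failed:", r)
    else:
        print(r.choices[0].message.content)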
@@ -2664,7 +2755,7 @@ def batch_completion_models_all_responses(*args, **kwargs):
### EMBEDDING ENDPOINTS ####################
@client
-async def aembedding(*args, **kwargs):
+async def aembedding(*args, **kwargs) -> EmbeddingResponse:
"""
Asynchronously calls the `embedding` function with the given arguments and keyword arguments.
@@ -2709,12 +2800,13 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai"
+ or custom_llm_provider == "databricks"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
- if isinstance(init_response, dict) or isinstance(
- init_response, ModelResponse
- ): ## CACHING SCENARIO
+ if isinstance(init_response, dict):
+ response = EmbeddingResponse(**init_response)
+ elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
@@ -2754,7 +2846,7 @@ def embedding(
litellm_logging_obj=None,
logger_fn=None,
**kwargs,
-):
+) -> EmbeddingResponse:
"""
Embedding function that calls an API to generate embeddings for the given input.
@@ -2902,7 +2994,7 @@ def embedding(
)
try:
response = None
- logging = litellm_logging_obj
+ logging: Logging = litellm_logging_obj # type: ignore
logging.update_environment_variables(
model=model,
user=user,
@@ -2992,6 +3084,32 @@ def embedding(
client=client,
aembedding=aembedding,
)
+ elif custom_llm_provider == "databricks":
+ api_base = (
+ api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")
+ ) # type: ignore
+
+ # set API KEY
+ api_key = (
+ api_key
+ or litellm.api_key
+ or litellm.databricks_key
+ or get_secret("DATABRICKS_API_KEY")
+ ) # type: ignore
+
+ ## EMBEDDING CALL
+ response = databricks_chat_completions.embedding(
+ model=model,
+ input=input,
+ api_base=api_base,
+ api_key=api_key,
+ logging_obj=logging,
+ timeout=timeout,
+ model_response=EmbeddingResponse(),
+ optional_params=optional_params,
+ client=client,
+ aembedding=aembedding,
+ )
elif custom_llm_provider == "cohere":
cohere_key = (
api_key
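The matching embedding sketch for the Databricks branch above; it assumes the same DATABRICKS_* environment variables as the completion example and uses the databricks-bge-large-en endpoint added to the pricing map later in this diff:

import litellm

emb = litellm.embedding(
    model="databricks/databricks-bge-large-en",
    input=["litellm now supports databricks embeddings"],
)
# entries follow the OpenAI embedding format; the pricing entry lists a 1024-dim output vector
print(len(emb.data[0]["embedding"]))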
@@ -3607,7 +3725,7 @@ async def amoderation(input: str, model: str, api_key: Optional[str] = None, **k
##### Image Generation #######################
@client
-async def aimage_generation(*args, **kwargs):
+async def aimage_generation(*args, **kwargs) -> ImageResponse:
"""
Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.
@@ -3640,6 +3758,8 @@ async def aimage_generation(*args, **kwargs):
if isinstance(init_response, dict) or isinstance(
init_response, ImageResponse
): ## CACHING SCENARIO
+ if isinstance(init_response, dict):
+ init_response = ImageResponse(**init_response)
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
@@ -3675,7 +3795,7 @@ def image_generation(
litellm_logging_obj=None,
custom_llm_provider=None,
**kwargs,
-):
+) -> ImageResponse:
"""
Maps the https://api.openai.com/v1/images/generations endpoint.
@@ -3851,6 +3971,36 @@ def image_generation(
model_response=model_response,
aimg_generation=aimg_generation,
)
+ elif custom_llm_provider == "vertex_ai":
+ vertex_ai_project = (
+ optional_params.pop("vertex_project", None)
+ or optional_params.pop("vertex_ai_project", None)
+ or litellm.vertex_project
+ or get_secret("VERTEXAI_PROJECT")
+ )
+ vertex_ai_location = (
+ optional_params.pop("vertex_location", None)
+ or optional_params.pop("vertex_ai_location", None)
+ or litellm.vertex_location
+ or get_secret("VERTEXAI_LOCATION")
+ )
+ vertex_credentials = (
+ optional_params.pop("vertex_credentials", None)
+ or optional_params.pop("vertex_ai_credentials", None)
+ or get_secret("VERTEXAI_CREDENTIALS")
+ )
+ model_response = vertex_chat_completion.image_generation(
+ model=model,
+ prompt=prompt,
+ timeout=timeout,
+ logging_obj=litellm_logging_obj,
+ optional_params=optional_params,
+ model_response=model_response,
+ vertex_project=vertex_ai_project,
+ vertex_location=vertex_ai_location,
+ aimg_generation=aimg_generation,
+ )
+
return model_response
except Exception as e:
## Map to OpenAI Exception
@@ -3977,7 +4127,7 @@ def transcription(
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_API_KEY")
- )
+ ) # type: ignore
response = azure_chat_completions.audio_transcriptions(
model=model,
@@ -3994,6 +4144,24 @@ def transcription(
max_retries=max_retries,
)
elif custom_llm_provider == "openai":
+ api_base = (
+ api_base
+ or litellm.api_base
+ or get_secret("OPENAI_API_BASE")
+ or "https://api.openai.com/v1"
+ ) # type: ignore
+ openai.organization = (
+ litellm.organization
+ or get_secret("OPENAI_ORGANIZATION")
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+ )
+ # set API KEY
+ api_key = (
+ api_key
+ or litellm.api_key
+ or litellm.openai_key
+ or get_secret("OPENAI_API_KEY")
+ ) # type: ignore
response = openai_chat_completions.audio_transcriptions(
model=model,
audio_file=file,
@@ -4003,6 +4171,139 @@ def transcription(
timeout=timeout,
logging_obj=litellm_logging_obj,
max_retries=max_retries,
+ api_base=api_base,
+ api_key=api_key,
+ )
+ return response
+
+
+@client
+async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:
+ """
+ Calls openai tts endpoints.
+ """
+ loop = asyncio.get_event_loop()
+ model = args[0] if len(args) > 0 else kwargs["model"]
+ ### PASS ARGS TO SPEECH ###
+ kwargs["aspeech"] = True
+ custom_llm_provider = kwargs.get("custom_llm_provider", None)
+ try:
+ # Use a partial function to pass your keyword arguments
+ func = partial(speech, *args, **kwargs)
+
+ # Add the context to the function
+ ctx = contextvars.copy_context()
+ func_with_context = partial(ctx.run, func)
+
+ _, custom_llm_provider, _, _ = get_llm_provider(
+ model=model, api_base=kwargs.get("api_base", None)
+ )
+
+ # Await normally
+ init_response = await loop.run_in_executor(None, func_with_context)
+ if asyncio.iscoroutine(init_response):
+ response = await init_response
+ else:
+ # Call the synchronous function using run_in_executor
+ response = await loop.run_in_executor(None, func_with_context)
+ return response # type: ignore
+ except Exception as e:
+ custom_llm_provider = custom_llm_provider or "openai"
+ raise exception_type(
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ original_exception=e,
+ completion_kwargs=args,
+ extra_kwargs=kwargs,
+ )
+
+
+@client
+def speech(
+ model: str,
+ input: str,
+ voice: str,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ organization: Optional[str] = None,
+ project: Optional[str] = None,
+ max_retries: Optional[int] = None,
+ metadata: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ response_format: Optional[str] = None,
+ speed: Optional[int] = None,
+ client=None,
+ headers: Optional[dict] = None,
+ custom_llm_provider: Optional[str] = None,
+ aspeech: Optional[bool] = None,
+ **kwargs,
+) -> HttpxBinaryResponseContent:
+
+ model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
+
+ optional_params = {}
+ if response_format is not None:
+ optional_params["response_format"] = response_format
+ if speed is not None:
+ optional_params["speed"] = speed # type: ignore
+
+ if timeout is None:
+ timeout = litellm.request_timeout
+
+ if max_retries is None:
+ max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
+ response: Optional[HttpxBinaryResponseContent] = None
+ if custom_llm_provider == "openai":
+ api_base = (
+ api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+ or litellm.api_base
+ or get_secret("OPENAI_API_BASE")
+ or "https://api.openai.com/v1"
+ ) # type: ignore
+ # set API KEY
+ api_key = (
+ api_key
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+ or litellm.openai_key
+ or get_secret("OPENAI_API_KEY")
+ ) # type: ignore
+
+ organization = (
+ organization
+ or litellm.organization
+ or get_secret("OPENAI_ORGANIZATION")
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+ ) # type: ignore
+
+ project = (
+ project
+ or litellm.project
+ or get_secret("OPENAI_PROJECT")
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+ ) # type: ignore
+
+ headers = headers or litellm.headers
+
+ response = openai_chat_completions.audio_speech(
+ model=model,
+ input=input,
+ voice=voice,
+ optional_params=optional_params,
+ api_key=api_key,
+ api_base=api_base,
+ organization=organization,
+ project=project,
+ max_retries=max_retries,
+ timeout=timeout,
+ client=client, # pass AsyncOpenAI, OpenAI client
+ aspeech=aspeech,
+ )
+
+ if response is None:
+ raise Exception(
+ "Unable to map the custom llm provider={} to a known provider={}.".format(
+ custom_llm_provider, litellm.provider_list
+ )
)
return response
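A usage sketch for the new speech()/aspeech() endpoints. It assumes the OpenAI tts-1 model, an OPENAI_API_KEY in the environment, and that the returned HttpxBinaryResponseContent exposes stream_to_file() as in the OpenAI SDK:

import litellm

audio = litellm.speech(
    model="openai/tts-1",
    voice="alloy",
    input="litellm can now generate speech",
    response_format="mp3",               # forwarded via optional_params
)
audio.stream_to_file("speech.mp3")       # write the binary response to disk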
@@ -4035,6 +4336,10 @@ async def ahealth_check(
mode = litellm.model_cost[model]["mode"]
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+
+ if model in litellm.model_cost and mode is None:
+ mode = litellm.model_cost[model]["mode"]
+
mode = mode or "chat" # default to chat completion calls
if custom_llm_provider == "azure":
@@ -4231,7 +4536,7 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
def stream_chunk_builder(
chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
-):
+) -> Union[ModelResponse, TextCompletionResponse]:
model_response = litellm.ModelResponse()
### SORT CHUNKS BASED ON CREATED ORDER ##
print_verbose("Goes into checking if chunk has hiddden created at param")
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index f3db33c60..3fe089a6b 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -380,6 +380,18 @@
"output_cost_per_second": 0.0001,
"litellm_provider": "azure"
},
+ "azure/gpt-4o": {
+ "max_tokens": 4096,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000005,
+ "output_cost_per_token": 0.000015,
+ "litellm_provider": "azure",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_parallel_function_calling": true,
+ "supports_vision": true
+ },
"azure/gpt-4-turbo-2024-04-09": {
"max_tokens": 4096,
"max_input_tokens": 128000,
@@ -518,8 +530,8 @@
"max_tokens": 4096,
"max_input_tokens": 4097,
"max_output_tokens": 4096,
- "input_cost_per_token": 0.0000015,
- "output_cost_per_token": 0.000002,
+ "input_cost_per_token": 0.0000005,
+ "output_cost_per_token": 0.0000015,
"litellm_provider": "azure",
"mode": "chat",
"supports_function_calling": true
@@ -692,8 +704,8 @@
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
- "input_cost_per_token": 0.00000015,
- "output_cost_per_token": 0.00000046,
+ "input_cost_per_token": 0.00000025,
+ "output_cost_per_token": 0.00000025,
"litellm_provider": "mistral",
"mode": "chat"
},
@@ -701,8 +713,8 @@
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
- "input_cost_per_token": 0.000002,
- "output_cost_per_token": 0.000006,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000003,
"litellm_provider": "mistral",
"supports_function_calling": true,
"mode": "chat"
@@ -711,8 +723,8 @@
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
- "input_cost_per_token": 0.000002,
- "output_cost_per_token": 0.000006,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000003,
"litellm_provider": "mistral",
"supports_function_calling": true,
"mode": "chat"
@@ -748,8 +760,8 @@
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
- "input_cost_per_token": 0.000008,
- "output_cost_per_token": 0.000024,
+ "input_cost_per_token": 0.000004,
+ "output_cost_per_token": 0.000012,
"litellm_provider": "mistral",
"mode": "chat",
"supports_function_calling": true
@@ -758,26 +770,63 @@
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
- "input_cost_per_token": 0.000008,
- "output_cost_per_token": 0.000024,
+ "input_cost_per_token": 0.000004,
+ "output_cost_per_token": 0.000012,
"litellm_provider": "mistral",
"mode": "chat",
"supports_function_calling": true
},
+ "mistral/open-mistral-7b": {
+ "max_tokens": 8191,
+ "max_input_tokens": 32000,
+ "max_output_tokens": 8191,
+ "input_cost_per_token": 0.00000025,
+ "output_cost_per_token": 0.00000025,
+ "litellm_provider": "mistral",
+ "mode": "chat"
+ },
"mistral/open-mixtral-8x7b": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
+ "input_cost_per_token": 0.0000007,
+ "output_cost_per_token": 0.0000007,
+ "litellm_provider": "mistral",
+ "mode": "chat",
+ "supports_function_calling": true
+ },
+ "mistral/open-mixtral-8x22b": {
+ "max_tokens": 8191,
+ "max_input_tokens": 64000,
+ "max_output_tokens": 8191,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006,
"litellm_provider": "mistral",
"mode": "chat",
"supports_function_calling": true
},
+ "mistral/codestral-latest": {
+ "max_tokens": 8191,
+ "max_input_tokens": 32000,
+ "max_output_tokens": 8191,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000003,
+ "litellm_provider": "mistral",
+ "mode": "chat"
+ },
+ "mistral/codestral-2405": {
+ "max_tokens": 8191,
+ "max_input_tokens": 32000,
+ "max_output_tokens": 8191,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000003,
+ "litellm_provider": "mistral",
+ "mode": "chat"
+ },
"mistral/mistral-embed": {
"max_tokens": 8192,
"max_input_tokens": 8192,
- "input_cost_per_token": 0.000000111,
+ "input_cost_per_token": 0.0000001,
"litellm_provider": "mistral",
"mode": "embedding"
},
@@ -1128,6 +1177,24 @@
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
+ "gemini-1.5-flash-001": {
+ "max_tokens": 8192,
+ "max_input_tokens": 1000000,
+ "max_output_tokens": 8192,
+ "max_images_per_prompt": 3000,
+ "max_videos_per_prompt": 10,
+ "max_video_length": 1,
+ "max_audio_length_hours": 8.4,
+ "max_audio_per_prompt": 1,
+ "max_pdf_size_mb": 30,
+ "input_cost_per_token": 0,
+ "output_cost_per_token": 0,
+ "litellm_provider": "vertex_ai-language-models",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_vision": true,
+ "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
+ },
"gemini-1.5-flash-preview-0514": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
@@ -1146,6 +1213,18 @@
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
+ "gemini-1.5-pro-001": {
+ "max_tokens": 8192,
+ "max_input_tokens": 1000000,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.000000625,
+ "output_cost_per_token": 0.000001875,
+ "litellm_provider": "vertex_ai-language-models",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": true,
+ "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
+ },
"gemini-1.5-pro-preview-0514": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
@@ -1265,13 +1344,19 @@
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
- "input_cost_per_token": 0.0000015,
- "output_cost_per_token": 0.0000075,
+ "input_cost_per_token": 0.000015,
+ "output_cost_per_token": 0.000075,
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
+ "vertex_ai/imagegeneration@006": {
+ "cost_per_image": 0.020,
+ "litellm_provider": "vertex_ai-image-models",
+ "mode": "image_generation",
+ "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
+ },
"textembedding-gecko": {
"max_tokens": 3072,
"max_input_tokens": 3072,
@@ -1415,7 +1500,7 @@
"max_pdf_size_mb": 30,
"input_cost_per_token": 0,
"output_cost_per_token": 0,
- "litellm_provider": "vertex_ai-language-models",
+ "litellm_provider": "gemini",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
@@ -1599,36 +1684,36 @@
"mode": "chat"
},
"replicate/meta/llama-3-70b": {
- "max_tokens": 4096,
- "max_input_tokens": 4096,
- "max_output_tokens": 4096,
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000275,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-3-70b-instruct": {
- "max_tokens": 4096,
- "max_input_tokens": 4096,
- "max_output_tokens": 4096,
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000275,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-3-8b": {
- "max_tokens": 4096,
- "max_input_tokens": 4096,
- "max_output_tokens": 4096,
+ "max_tokens": 8086,
+ "max_input_tokens": 8086,
+ "max_output_tokens": 8086,
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-3-8b-instruct": {
- "max_tokens": 4096,
- "max_input_tokens": 4096,
- "max_output_tokens": 4096,
+ "max_tokens": 8086,
+ "max_input_tokens": 8086,
+ "max_output_tokens": 8086,
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
@@ -1892,7 +1977,7 @@
"mode": "chat"
},
"openrouter/meta-llama/codellama-34b-instruct": {
- "max_tokens": 8096,
+ "max_tokens": 8192,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter",
@@ -3384,9 +3469,10 @@
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat",
- "supports_function_calling": true
+ "supports_function_calling": true,
+ "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mistral-7B-Instruct-v0.1"
},
- "anyscale/Mixtral-8x7B-Instruct-v0.1": {
+ "anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1": {
"max_tokens": 16384,
"max_input_tokens": 16384,
"max_output_tokens": 16384,
@@ -3394,7 +3480,19 @@
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat",
- "supports_function_calling": true
+ "supports_function_calling": true,
+ "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x7B-Instruct-v0.1"
+ },
+ "anyscale/mistralai/Mixtral-8x22B-Instruct-v0.1": {
+ "max_tokens": 65536,
+ "max_input_tokens": 65536,
+ "max_output_tokens": 65536,
+ "input_cost_per_token": 0.00000090,
+ "output_cost_per_token": 0.00000090,
+ "litellm_provider": "anyscale",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x22B-Instruct-v0.1"
},
"anyscale/HuggingFaceH4/zephyr-7b-beta": {
"max_tokens": 16384,
@@ -3405,6 +3503,16 @@
"litellm_provider": "anyscale",
"mode": "chat"
},
+ "anyscale/google/gemma-7b-it": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.00000015,
+ "output_cost_per_token": 0.00000015,
+ "litellm_provider": "anyscale",
+ "mode": "chat",
+ "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/google-gemma-7b-it"
+ },
"anyscale/meta-llama/Llama-2-7b-chat-hf": {
"max_tokens": 4096,
"max_input_tokens": 4096,
@@ -3441,6 +3549,36 @@
"litellm_provider": "anyscale",
"mode": "chat"
},
+ "anyscale/codellama/CodeLlama-70b-Instruct-hf": {
+ "max_tokens": 4096,
+ "max_input_tokens": 4096,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "anyscale",
+ "mode": "chat",
+ "source" : "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/codellama-CodeLlama-70b-Instruct-hf"
+ },
+ "anyscale/meta-llama/Meta-Llama-3-8B-Instruct": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.00000015,
+ "output_cost_per_token": 0.00000015,
+ "litellm_provider": "anyscale",
+ "mode": "chat",
+ "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-8B-Instruct"
+ },
+ "anyscale/meta-llama/Meta-Llama-3-70B-Instruct": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.00000100,
+ "output_cost_per_token": 0.00000100,
+ "litellm_provider": "anyscale",
+ "mode": "chat",
+ "source" : "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-70B-Instruct"
+ },
"cloudflare/@cf/meta/llama-2-7b-chat-fp16": {
"max_tokens": 3072,
"max_input_tokens": 3072,
@@ -3532,6 +3670,76 @@
"output_cost_per_token": 0.000000,
"litellm_provider": "voyage",
"mode": "embedding"
- }
+ },
+ "databricks/databricks-dbrx-instruct": {
+ "max_tokens": 32768,
+ "max_input_tokens": 32768,
+ "max_output_tokens": 32768,
+ "input_cost_per_token": 0.00000075,
+ "output_cost_per_token": 0.00000225,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-meta-llama-3-70b-instruct": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000003,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-llama-2-70b-chat": {
+ "max_tokens": 4096,
+ "max_input_tokens": 4096,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000005,
+ "output_cost_per_token": 0.0000015,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-mixtral-8x7b-instruct": {
+ "max_tokens": 4096,
+ "max_input_tokens": 4096,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000005,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-mpt-30b-instruct": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-mpt-7b-instruct": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.0000005,
+ "output_cost_per_token": 0.0000005,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-bge-large-en": {
+ "max_tokens": 512,
+ "max_input_tokens": 512,
+ "output_vector_size": 1024,
+ "input_cost_per_token": 0.0000001,
+ "output_cost_per_token": 0.0,
+ "litellm_provider": "databricks",
+ "mode": "embedding",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ }
}
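The new pricing entries can be inspected straight from the bundled cost map; a quick sketch against the azure/gpt-4o entry added above (the numbers simply mirror the JSON, not an authoritative price list):

import litellm

info = litellm.model_cost["azure/gpt-4o"]
print(info["max_input_tokens"])        # 128000
print(info["input_cost_per_token"])    # 5e-06

# cost_per_token() reads the same map
prompt_cost, completion_cost = litellm.cost_per_token(
    model="azure/gpt-4o", prompt_tokens=1000, completion_tokens=200
)
print(prompt_cost + completion_cost)   # 1000*5e-06 + 200*1.5e-05 = 0.008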
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
index 3e58fe524..41cc292f2 100644
--- a/litellm/proxy/_experimental/out/404.html
+++ b/litellm/proxy/_experimental/out/404.html
@@ -1 +1 @@
-404: This page could not be found.LiteLLM Dashboard
\ No newline at end of file
+404: This page could not be found.LiteLLM Dashboard