+
+ import openai
+ client = openai.OpenAI(
+ api_key="{key_token}",
+ base_url={os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")}
+ )
+
+ response = client.chat.completions.create(
+ model="gpt-3.5-turbo", # model to send to the proxy
+ messages = [
+ {{
+ "role": "user",
+ "content": "this is a test request, write a short poem"
+ }}
+ ]
+ )
+
+
+
+
+ If you have any questions, please send an email to {email_support_contact}
+
+ Best,
+ The LiteLLM team
+ """
+
+ payload = webhook_event.model_dump_json()
+ email_event = {
+ "to": recipient_email,
+ "subject": f"LiteLLM: {event_name}",
+ "html": email_html_content,
+ }
+
+ response = await send_email(
+ receiver_email=email_event["to"],
+ subject=email_event["subject"],
+ html=email_event["html"],
+ )
+
+ return False
+
+ async def send_email_alert_using_smtp(self, webhook_event: WebhookEvent) -> bool:
+ """
+ Sends a structured email alert via an SMTP server
+
+ Currently only implemented for budget alerts
+
+ Returns -> True if sent, False if not.
+ """
+ from litellm.proxy.utils import send_email
+
+ from litellm.proxy.proxy_server import premium_user, prisma_client
+
+ email_logo_url = os.getenv("SMTP_SENDER_LOGO", None)
+ email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
+ await self._check_if_using_premium_email_feature(
+ premium_user, email_logo_url, email_support_contact
+ )
+
+ if email_logo_url is None:
+ email_logo_url = LITELLM_LOGO_URL
+ if email_support_contact is None:
+ email_support_contact = LITELLM_SUPPORT_CONTACT
+
+ event_name = webhook_event.event_message
+ recipient_email = webhook_event.user_email
+ user_name = webhook_event.user_id
+ max_budget = webhook_event.max_budget
+ email_html_content = "Alert from LiteLLM Server"
+ if recipient_email is None:
+ verbose_proxy_logger.error(
+ "Trying to send email alert to no recipient", extra=webhook_event.dict()
+ )
+
+ if webhook_event.event == "budget_crossed":
+ email_html_content = f"""
+
+
+
+ Hi {user_name},
+
+ Your LLM API usage this month has reached your account's monthly budget of ${max_budget}
+
+ API requests will be rejected until either (a) you increase your monthly budget or (b) your monthly usage resets at the beginning of the next calendar month.
+
+ If you have any questions, please send an email to {email_support_contact}
+
+ Best,
+ The LiteLLM team
+ """
+
+ payload = webhook_event.model_dump_json()
+ email_event = {
+ "to": recipient_email,
+ "subject": f"LiteLLM: {event_name}",
+ "html": email_html_content,
+ }
+
+ response = await send_email(
+ receiver_email=email_event["to"],
+ subject=email_event["subject"],
+ html=email_event["html"],
+ )
+
+ return False
+
async def send_alert(
self,
message: str,
level: Literal["Low", "Medium", "High"],
- alert_type: Literal[
- "llm_exceptions",
- "llm_too_slow",
- "llm_requests_hanging",
- "budget_alerts",
- "db_exceptions",
- "daily_reports",
- "new_model_added",
- "cooldown_deployment",
- ],
+ alert_type: Literal[AlertType],
+ user_info: Optional[WebhookEvent] = None,
**kwargs,
):
"""
@@ -748,6 +1379,27 @@ Model Info:
if self.alerting is None:
return
+ if (
+ "webhook" in self.alerting
+ and alert_type == "budget_alerts"
+ and user_info is not None
+ ):
+ await self.send_webhook_alert(webhook_event=user_info)
+
+ if (
+ "email" in self.alerting
+ and alert_type == "budget_alerts"
+ and user_info is not None
+ ):
+ # only send budget alerts over Email
+ await self.send_email_alert_using_smtp(webhook_event=user_info)
+
+ if "slack" not in self.alerting:
+ return
+
+ if alert_type not in self.alert_types:
+ return
+
from datetime import datetime
import json
@@ -791,46 +1443,78 @@ Model Info:
if response.status_code == 200:
pass
else:
- print("Error sending slack alert. Error=", response.text) # noqa
+ verbose_proxy_logger.debug(
+ "Error sending slack alert. Error=", response.text
+ )
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""Log deployment latency"""
- if "daily_reports" in self.alert_types:
- model_id = (
- kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
- )
- response_s: timedelta = end_time - start_time
-
- final_value = response_s
- total_tokens = 0
-
- if isinstance(response_obj, litellm.ModelResponse):
- completion_tokens = response_obj.usage.completion_tokens
- final_value = float(response_s.total_seconds() / completion_tokens)
-
- await self.async_update_daily_reports(
- DeploymentMetrics(
- id=model_id,
- failed_request=False,
- latency_per_output_token=final_value,
- updated_at=litellm.utils.get_utc_datetime(),
+ try:
+ if "daily_reports" in self.alert_types:
+ model_id = (
+ kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
)
+ response_s: timedelta = end_time - start_time
+
+ final_value = response_s
+ total_tokens = 0
+
+ if isinstance(response_obj, litellm.ModelResponse):
+ completion_tokens = response_obj.usage.completion_tokens
+ if completion_tokens is not None and completion_tokens > 0:
+ final_value = float(
+ response_s.total_seconds() / completion_tokens
+ )
+ if isinstance(final_value, timedelta):
+ final_value = final_value.total_seconds()
+
+ await self.async_update_daily_reports(
+ DeploymentMetrics(
+ id=model_id,
+ failed_request=False,
+ latency_per_output_token=final_value,
+ updated_at=litellm.utils.get_utc_datetime(),
+ )
+ )
+ except Exception as e:
+ verbose_proxy_logger.error(
+ "[Non-Blocking Error] Slack Alerting: Got error in logging LLM deployment latency: ",
+ e,
)
+ pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
"""Log failure + deployment latency"""
- if "daily_reports" in self.alert_types:
- model_id = (
- kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
- )
- await self.async_update_daily_reports(
- DeploymentMetrics(
- id=model_id,
- failed_request=True,
- latency_per_output_token=None,
- updated_at=litellm.utils.get_utc_datetime(),
- )
- )
+ _litellm_params = kwargs.get("litellm_params", {})
+ _model_info = _litellm_params.get("model_info", {}) or {}
+ model_id = _model_info.get("id", "")
+ try:
+ if "daily_reports" in self.alert_types:
+ try:
+ await self.async_update_daily_reports(
+ DeploymentMetrics(
+ id=model_id,
+ failed_request=True,
+ latency_per_output_token=None,
+ updated_at=litellm.utils.get_utc_datetime(),
+ )
+ )
+ except Exception as e:
+ verbose_logger.debug(f"Exception raises -{str(e)}")
+
+ if isinstance(kwargs.get("exception", ""), APIError):
+ if "outage_alerts" in self.alert_types:
+ await self.outage_alerts(
+ exception=kwargs["exception"],
+ deployment_id=model_id,
+ )
+
+ if "region_outage_alerts" in self.alert_types:
+ await self.region_outage_alerts(
+ exception=kwargs["exception"], deployment_id=model_id
+ )
+ except Exception as e:
+ pass
async def _run_scheduler_helper(self, llm_router) -> bool:
"""
@@ -842,40 +1526,26 @@ Model Info:
report_sent = await self.internal_usage_cache.async_get_cache(
key=SlackAlertingCacheKeys.report_sent_key.value
- ) # None | datetime
+ ) # None | float
- current_time = litellm.utils.get_utc_datetime()
+ current_time = time.time()
if report_sent is None:
- _current_time = current_time.isoformat()
await self.internal_usage_cache.async_set_cache(
key=SlackAlertingCacheKeys.report_sent_key.value,
- value=_current_time,
+ value=current_time,
)
- else:
+ elif isinstance(report_sent, float):
# Check if current time - interval >= time last sent
- delta_naive = timedelta(seconds=self.alerting_args.daily_report_frequency)
- if isinstance(report_sent, str):
- report_sent = dt.fromisoformat(report_sent)
+ interval_seconds = self.alerting_args.daily_report_frequency
- # Ensure report_sent is an aware datetime object
- if report_sent.tzinfo is None:
- report_sent = report_sent.replace(tzinfo=timezone.utc)
-
- # Calculate delta as an aware datetime object with the same timezone as report_sent
- delta = report_sent - delta_naive
-
- current_time_utc = current_time.astimezone(timezone.utc)
- delta_utc = delta.astimezone(timezone.utc)
-
- if current_time_utc >= delta_utc:
+ if current_time - report_sent >= interval_seconds:
# Sneak in the reporting logic here
await self.send_daily_reports(router=llm_router)
# Also, don't forget to update the report_sent time after sending the report!
- _current_time = current_time.isoformat()
await self.internal_usage_cache.async_set_cache(
key=SlackAlertingCacheKeys.report_sent_key.value,
- value=_current_time,
+ value=current_time,
)
report_sent_bool = True
@@ -942,7 +1612,7 @@ Model Info:
await self.send_alert(
message=_weekly_spend_message,
level="Low",
- alert_type="daily_reports",
+ alert_type="spend_reports",
)
except Exception as e:
verbose_proxy_logger.error("Error sending weekly spend report", e)
@@ -993,7 +1663,7 @@ Model Info:
await self.send_alert(
message=_spend_message,
level="Low",
- alert_type="daily_reports",
+ alert_type="spend_reports",
)
except Exception as e:
verbose_proxy_logger.error("Error sending weekly spend report", e)
diff --git a/litellm/integrations/traceloop.py b/litellm/integrations/traceloop.py
index bbdb9a1b0..39d62028e 100644
--- a/litellm/integrations/traceloop.py
+++ b/litellm/integrations/traceloop.py
@@ -1,114 +1,153 @@
+import traceback
+from litellm._logging import verbose_logger
+import litellm
+
+
class TraceloopLogger:
def __init__(self):
- from traceloop.sdk.tracing.tracing import TracerWrapper
- from traceloop.sdk import Traceloop
+ try:
+ from traceloop.sdk.tracing.tracing import TracerWrapper
+ from traceloop.sdk import Traceloop
+ from traceloop.sdk.instruments import Instruments
+ except ModuleNotFoundError as e:
+ verbose_logger.error(
+ f"Traceloop not installed, try running 'pip install traceloop-sdk' to fix this error: {e}\n{traceback.format_exc()}"
+ )
- Traceloop.init(app_name="Litellm-Server", disable_batch=True)
+ Traceloop.init(
+ app_name="Litellm-Server",
+ disable_batch=True,
+ instruments=[
+ Instruments.CHROMA,
+ Instruments.PINECONE,
+ Instruments.WEAVIATE,
+ Instruments.LLAMA_INDEX,
+ Instruments.LANGCHAIN,
+ ],
+ )
self.tracer_wrapper = TracerWrapper()
- def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
- from opentelemetry.trace import SpanKind
+ def log_event(
+ self,
+ kwargs,
+ response_obj,
+ start_time,
+ end_time,
+ user_id,
+ print_verbose,
+ level="DEFAULT",
+ status_message=None,
+ ):
+ from opentelemetry import trace
+ from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.semconv.ai import SpanAttributes
try:
+ print_verbose(
+ f"Traceloop Logging - Enters logging function for model {kwargs}"
+ )
+
tracer = self.tracer_wrapper.get_tracer()
- model = kwargs.get("model")
-
- # LiteLLM uses the standard OpenAI library, so it's already handled by Traceloop SDK
- if kwargs.get("litellm_params").get("custom_llm_provider") == "openai":
- return
-
optional_params = kwargs.get("optional_params", {})
- with tracer.start_as_current_span(
- "litellm.completion",
- kind=SpanKind.CLIENT,
- ) as span:
- if span.is_recording():
+ span = tracer.start_span(
+ "litellm.completion", kind=SpanKind.CLIENT, start_time=start_time
+ )
+
+ if span.is_recording():
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model")
+ )
+ if "stop" in optional_params:
span.set_attribute(
- SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model")
+ SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
+ optional_params.get("stop"),
)
- if "stop" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
- optional_params.get("stop"),
- )
- if "frequency_penalty" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_FREQUENCY_PENALTY,
- optional_params.get("frequency_penalty"),
- )
- if "presence_penalty" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_PRESENCE_PENALTY,
- optional_params.get("presence_penalty"),
- )
- if "top_p" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_TOP_P, optional_params.get("top_p")
- )
- if "tools" in optional_params or "functions" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_REQUEST_FUNCTIONS,
- optional_params.get(
- "tools", optional_params.get("functions")
- ),
- )
- if "user" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_USER, optional_params.get("user")
- )
- if "max_tokens" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_REQUEST_MAX_TOKENS,
- kwargs.get("max_tokens"),
- )
- if "temperature" in optional_params:
- span.set_attribute(
- SpanAttributes.LLM_TEMPERATURE, kwargs.get("temperature")
- )
-
- for idx, prompt in enumerate(kwargs.get("messages")):
- span.set_attribute(
- f"{SpanAttributes.LLM_PROMPTS}.{idx}.role",
- prompt.get("role"),
- )
- span.set_attribute(
- f"{SpanAttributes.LLM_PROMPTS}.{idx}.content",
- prompt.get("content"),
- )
-
+ if "frequency_penalty" in optional_params:
span.set_attribute(
- SpanAttributes.LLM_RESPONSE_MODEL, response_obj.get("model")
+ SpanAttributes.LLM_FREQUENCY_PENALTY,
+ optional_params.get("frequency_penalty"),
+ )
+ if "presence_penalty" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_PRESENCE_PENALTY,
+ optional_params.get("presence_penalty"),
+ )
+ if "top_p" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_TOP_P, optional_params.get("top_p")
+ )
+ if "tools" in optional_params or "functions" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_FUNCTIONS,
+ optional_params.get("tools", optional_params.get("functions")),
+ )
+ if "user" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_USER, optional_params.get("user")
+ )
+ if "max_tokens" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_MAX_TOKENS,
+ kwargs.get("max_tokens"),
+ )
+ if "temperature" in optional_params:
+ span.set_attribute(
+ SpanAttributes.LLM_REQUEST_TEMPERATURE,
+ kwargs.get("temperature"),
)
- usage = response_obj.get("usage")
- if usage:
- span.set_attribute(
- SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
- usage.get("total_tokens"),
- )
- span.set_attribute(
- SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
- usage.get("completion_tokens"),
- )
- span.set_attribute(
- SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
- usage.get("prompt_tokens"),
- )
- for idx, choice in enumerate(response_obj.get("choices")):
- span.set_attribute(
- f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason",
- choice.get("finish_reason"),
- )
- span.set_attribute(
- f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role",
- choice.get("message").get("role"),
- )
- span.set_attribute(
- f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content",
- choice.get("message").get("content"),
- )
+ for idx, prompt in enumerate(kwargs.get("messages")):
+ span.set_attribute(
+ f"{SpanAttributes.LLM_PROMPTS}.{idx}.role",
+ prompt.get("role"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_PROMPTS}.{idx}.content",
+ prompt.get("content"),
+ )
+
+ span.set_attribute(
+ SpanAttributes.LLM_RESPONSE_MODEL, response_obj.get("model")
+ )
+ usage = response_obj.get("usage")
+ if usage:
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
+ usage.get("total_tokens"),
+ )
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+ usage.get("completion_tokens"),
+ )
+ span.set_attribute(
+ SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+ usage.get("prompt_tokens"),
+ )
+
+ for idx, choice in enumerate(response_obj.get("choices")):
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason",
+ choice.get("finish_reason"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role",
+ choice.get("message").get("role"),
+ )
+ span.set_attribute(
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content",
+ choice.get("message").get("content"),
+ )
+
+ if (
+ level == "ERROR"
+ and status_message is not None
+ and isinstance(status_message, str)
+ ):
+ span.record_exception(Exception(status_message))
+ span.set_status(Status(StatusCode.ERROR, status_message))
+
+ span.end(end_time)
except Exception as e:
print_verbose(f"Traceloop Layer Error - {e}")
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 97a473a2e..ec6854a0f 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -10,6 +10,7 @@ from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
+from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
class AnthropicConstants(Enum):
@@ -93,6 +94,7 @@ class AnthropicConfig:
"max_tokens",
"tools",
"tool_choice",
+ "extra_headers",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
@@ -101,6 +103,17 @@ class AnthropicConfig:
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
+ if param == "tool_choice":
+ _tool_choice: Optional[AnthropicMessagesToolChoice] = None
+ if value == "auto":
+ _tool_choice = {"type": "auto"}
+ elif value == "required":
+ _tool_choice = {"type": "any"}
+ elif isinstance(value, dict):
+ _tool_choice = {"type": "tool", "name": value["function"]["name"]}
+
+ if _tool_choice is not None:
+ optional_params["tool_choice"] = _tool_choice
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":
@@ -366,13 +379,12 @@ class AnthropicChatCompletion(BaseLLM):
logger_fn=None,
headers={},
):
- self.async_handler = AsyncHTTPHandler(
- timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+
+ async_handler = AsyncHTTPHandler(
+ timeout=httpx.Timeout(timeout=600.0, connect=20.0)
)
data["stream"] = True
- response = await self.async_handler.post(
- api_base, headers=headers, data=json.dumps(data), stream=True
- )
+ response = await async_handler.post(api_base, headers=headers, json=data)
if response.status_code != 200:
raise AnthropicError(
@@ -408,12 +420,10 @@ class AnthropicChatCompletion(BaseLLM):
logger_fn=None,
headers={},
) -> Union[ModelResponse, CustomStreamWrapper]:
- self.async_handler = AsyncHTTPHandler(
+ async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
- response = await self.async_handler.post(
- api_base, headers=headers, data=json.dumps(data)
- )
+ response = await async_handler.post(api_base, headers=headers, json=data)
if stream and _is_function_call:
return self.process_streaming_response(
model=model,
@@ -504,7 +514,9 @@ class AnthropicChatCompletion(BaseLLM):
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
- headers["anthropic-beta"] = "tools-2024-04-04"
+ if "anthropic-beta" not in headers:
+ # default to v1 of "anthropic-beta"
+ headers["anthropic-beta"] = "tools-2024-05-16"
anthropic_tools = []
for tool in optional_params["tools"]:
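
The anthropic.py hunk above maps OpenAI's tool_choice parameter onto Anthropic's tool_choice object. A small standalone sketch of that mapping (the helper name and the get_weather example are made up; the dict shapes follow the hunk itself):

from typing import Optional, Union


def map_tool_choice(value: Union[str, dict]) -> Optional[dict]:
    """Translate an OpenAI tool_choice value into Anthropic's tool_choice object."""
    if value == "auto":
        return {"type": "auto"}
    if value == "required":  # OpenAI "required" maps to Anthropic "any"
        return {"type": "any"}
    if isinstance(value, dict):  # e.g. {"type": "function", "function": {"name": ...}}
        return {"type": "tool", "name": value["function"]["name"]}
    return None  # unrecognized values are dropped, as in map_openai_params


assert map_tool_choice("required") == {"type": "any"}
assert map_tool_choice(
    {"type": "function", "function": {"name": "get_weather"}}
) == {"type": "tool", "name": "get_weather"}
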
diff --git a/litellm/llms/base.py b/litellm/llms/base.py
index d940d9471..8c2f5101e 100644
--- a/litellm/llms/base.py
+++ b/litellm/llms/base.py
@@ -21,7 +21,7 @@ class BaseLLM:
messages: list,
print_verbose,
encoding,
- ) -> litellm.utils.ModelResponse:
+ ) -> Union[litellm.utils.ModelResponse, litellm.utils.CustomStreamWrapper]:
"""
Helper function to process the response across sync + async completion calls
"""
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 1ff3767bd..337055dc2 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -1,6 +1,6 @@
# What is this?
## Initial implementation of calling bedrock via httpx client (allows for async calls).
-## V0 - just covers cohere command-r support
+## V1 - covers cohere + anthropic claude-3 support
import os, types
import json
@@ -29,13 +29,22 @@ from litellm.utils import (
get_secret,
Logging,
)
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
+import litellm, uuid
+from .prompt_templates.factory import (
+ prompt_factory,
+ custom_prompt,
+ cohere_message_pt,
+ construct_tool_use_system_prompt,
+ extract_between_tags,
+ parse_xml_params,
+ contains_tag,
+)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
import httpx # type: ignore
-from .bedrock import BedrockError, convert_messages_to_prompt
+from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
from litellm.types.llms.bedrock import *
+import urllib.parse
class AmazonCohereChatConfig:
@@ -280,7 +289,8 @@ class BedrockLLM(BaseLLM):
messages: List,
print_verbose,
encoding,
- ) -> ModelResponse:
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ provider = model.split(".")[0]
## LOGGING
logging_obj.post_call(
input=messages,
@@ -297,26 +307,210 @@ class BedrockLLM(BaseLLM):
raise BedrockError(message=response.text, status_code=422)
try:
- model_response.choices[0].message.content = completion_response["text"] # type: ignore
+ if provider == "cohere":
+ if "text" in completion_response:
+ outputText = completion_response["text"] # type: ignore
+ elif "generations" in completion_response:
+ outputText = completion_response["generations"][0]["text"]
+ model_response["finish_reason"] = map_finish_reason(
+ completion_response["generations"][0]["finish_reason"]
+ )
+ elif provider == "anthropic":
+ if model.startswith("anthropic.claude-3"):
+ json_schemas: dict = {}
+ _is_function_call = False
+ ## Handle Tool Calling
+ if "tools" in optional_params:
+ _is_function_call = True
+ for tool in optional_params["tools"]:
+ json_schemas[tool["function"]["name"]] = tool[
+ "function"
+ ].get("parameters", None)
+ outputText = completion_response.get("content")[0].get("text", None)
+ if outputText is not None and contains_tag(
+ "invoke", outputText
+ ): # OUTPUT PARSE FUNCTION CALL
+ function_name = extract_between_tags("tool_name", outputText)[0]
+ function_arguments_str = extract_between_tags(
+ "invoke", outputText
+ )[0].strip()
+ function_arguments_str = (
+ f"<invoke>{function_arguments_str}</invoke>"
+ )
+ function_arguments = parse_xml_params(
+ function_arguments_str,
+ json_schema=json_schemas.get(
+ function_name, None
+ ), # check if we have a json schema for this function name)
+ )
+ _message = litellm.Message(
+ tool_calls=[
+ {
+ "id": f"call_{uuid.uuid4()}",
+ "type": "function",
+ "function": {
+ "name": function_name,
+ "arguments": json.dumps(function_arguments),
+ },
+ }
+ ],
+ content=None,
+ )
+ model_response.choices[0].message = _message # type: ignore
+ model_response._hidden_params["original_response"] = (
+ outputText # allow user to access raw anthropic tool calling response
+ )
+ if (
+ _is_function_call == True
+ and stream is not None
+ and stream == True
+ ):
+ print_verbose(
+ f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
+ )
+ # return an iterator
+ streaming_model_response = ModelResponse(stream=True)
+ streaming_model_response.choices[0].finish_reason = getattr(
+ model_response.choices[0], "finish_reason", "stop"
+ )
+ # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+ streaming_choice = litellm.utils.StreamingChoices()
+ streaming_choice.index = model_response.choices[0].index
+ _tool_calls = []
+ print_verbose(
+ f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+ )
+ print_verbose(
+ f"type of streaming_choice: {type(streaming_choice)}"
+ )
+ if isinstance(model_response.choices[0], litellm.Choices):
+ if getattr(
+ model_response.choices[0].message, "tool_calls", None
+ ) is not None and isinstance(
+ model_response.choices[0].message.tool_calls, list
+ ):
+ for tool_call in model_response.choices[
+ 0
+ ].message.tool_calls:
+ _tool_call = {**tool_call.dict(), "index": 0}
+ _tool_calls.append(_tool_call)
+ delta_obj = litellm.utils.Delta(
+ content=getattr(
+ model_response.choices[0].message, "content", None
+ ),
+ role=model_response.choices[0].message.role,
+ tool_calls=_tool_calls,
+ )
+ streaming_choice.delta = delta_obj
+ streaming_model_response.choices = [streaming_choice]
+ completion_stream = ModelResponseIterator(
+ model_response=streaming_model_response
+ )
+ print_verbose(
+ f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+ )
+ return litellm.CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider="cached_response",
+ logging_obj=logging_obj,
+ )
+
+ model_response["finish_reason"] = map_finish_reason(
+ completion_response.get("stop_reason", "")
+ )
+ _usage = litellm.Usage(
+ prompt_tokens=completion_response["usage"]["input_tokens"],
+ completion_tokens=completion_response["usage"]["output_tokens"],
+ total_tokens=completion_response["usage"]["input_tokens"]
+ + completion_response["usage"]["output_tokens"],
+ )
+ setattr(model_response, "usage", _usage)
+ else:
+ outputText = completion_response["completion"]
+
+ model_response["finish_reason"] = completion_response["stop_reason"]
+ elif provider == "ai21":
+ outputText = (
+ completion_response.get("completions")[0].get("data").get("text")
+ )
+ elif provider == "meta":
+ outputText = completion_response["generation"]
+ elif provider == "mistral":
+ outputText = completion_response["outputs"][0]["text"]
+ model_response["finish_reason"] = completion_response["outputs"][0][
+ "stop_reason"
+ ]
+ else: # amazon titan
+ outputText = completion_response.get("results")[0].get("outputText")
except Exception as e:
- raise BedrockError(message=response.text, status_code=422)
+ raise BedrockError(
+ message="Error processing={}, Received error={}".format(
+ response.text, str(e)
+ ),
+ status_code=422,
+ )
+
+ try:
+ if (
+ len(outputText) > 0
+ and hasattr(model_response.choices[0], "message")
+ and getattr(model_response.choices[0].message, "tool_calls", None)
+ is None
+ ):
+ model_response["choices"][0]["message"]["content"] = outputText
+ elif (
+ hasattr(model_response.choices[0], "message")
+ and getattr(model_response.choices[0].message, "tool_calls", None)
+ is not None
+ ):
+ pass
+ else:
+ raise Exception()
+ except:
+ raise BedrockError(
+ message=json.dumps(outputText), status_code=response.status_code
+ )
+
+ if stream and provider == "ai21":
+ streaming_model_response = ModelResponse(stream=True)
+ streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
+ 0
+ ].finish_reason
+ # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+ streaming_choice = litellm.utils.StreamingChoices()
+ streaming_choice.index = model_response.choices[0].index
+ delta_obj = litellm.utils.Delta(
+ content=getattr(model_response.choices[0].message, "content", None),
+ role=model_response.choices[0].message.role,
+ )
+ streaming_choice.delta = delta_obj
+ streaming_model_response.choices = [streaming_choice]
+ mri = ModelResponseIterator(model_response=streaming_model_response)
+ return CustomStreamWrapper(
+ completion_stream=mri,
+ model=model,
+ custom_llm_provider="cached_response",
+ logging_obj=logging_obj,
+ )
## CALCULATING USAGE - bedrock returns usage in the headers
- prompt_tokens = int(
- response.headers.get(
- "x-amzn-bedrock-input-token-count",
- len(encoding.encode("".join(m.get("content", "") for m in messages))),
- )
+ bedrock_input_tokens = response.headers.get(
+ "x-amzn-bedrock-input-token-count", None
)
+ bedrock_output_tokens = response.headers.get(
+ "x-amzn-bedrock-output-token-count", None
+ )
+
+ prompt_tokens = int(
+ bedrock_input_tokens or litellm.token_counter(messages=messages)
+ )
+
completion_tokens = int(
- response.headers.get(
- "x-amzn-bedrock-output-token-count",
- len(
- encoding.encode(
- model_response.choices[0].message.content, # type: ignore
- disallowed_special=(),
- )
- ),
+ bedrock_output_tokens
+ or litellm.token_counter(
+ text=model_response.choices[0].message.content, # type: ignore
+ count_response_tokens=True,
)
)
@@ -331,6 +525,16 @@ class BedrockLLM(BaseLLM):
return model_response
+ def encode_model_id(self, model_id: str) -> str:
+ """
+ Double encode the model ID to ensure it matches the expected double-encoded format.
+ Args:
+ model_id (str): The model ID to encode.
+ Returns:
+ str: The double-encoded model ID.
+ """
+ return urllib.parse.quote(model_id, safe="")
+
def completion(
self,
model: str,
@@ -359,6 +563,13 @@ class BedrockLLM(BaseLLM):
## SETUP ##
stream = optional_params.pop("stream", None)
+ modelId = optional_params.pop("model_id", None)
+ if modelId is not None:
+ modelId = self.encode_model_id(model_id=modelId)
+ else:
+ modelId = model
+
+ provider = model.split(".")[0]
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@@ -414,19 +625,18 @@ class BedrockLLM(BaseLLM):
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
- if stream is not None and stream == True:
- endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
+ if (stream is not None and stream == True) and provider != "ai21":
+ endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
else:
- endpoint_url = f"{endpoint_url}/model/{model}/invoke"
+ endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
- provider = model.split(".")[0]
prompt, chat_history = self.convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
inference_params = copy.deepcopy(optional_params)
-
+ json_schemas: dict = {}
if provider == "cohere":
if model.startswith("cohere.command-r"):
## LOAD CONFIG
@@ -453,8 +663,114 @@ class BedrockLLM(BaseLLM):
True # cohere requires stream = True in inference params
)
data = json.dumps({"prompt": prompt, **inference_params})
+ elif provider == "anthropic":
+ if model.startswith("anthropic.claude-3"):
+ # Separate system prompt from rest of message
+ system_prompt_idx: list[int] = []
+ system_messages: list[str] = []
+ for idx, message in enumerate(messages):
+ if message["role"] == "system":
+ system_messages.append(message["content"])
+ system_prompt_idx.append(idx)
+ if len(system_prompt_idx) > 0:
+ inference_params["system"] = "\n".join(system_messages)
+ messages = [
+ i for j, i in enumerate(messages) if j not in system_prompt_idx
+ ]
+ # Format rest of message according to anthropic guidelines
+ messages = prompt_factory(
+ model=model, messages=messages, custom_llm_provider="anthropic_xml"
+ ) # type: ignore
+ ## LOAD CONFIG
+ config = litellm.AmazonAnthropicClaude3Config.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+ ## Handle Tool Calling
+ if "tools" in inference_params:
+ _is_function_call = True
+ for tool in inference_params["tools"]:
+ json_schemas[tool["function"]["name"]] = tool["function"].get(
+ "parameters", None
+ )
+ tool_calling_system_prompt = construct_tool_use_system_prompt(
+ tools=inference_params["tools"]
+ )
+ inference_params["system"] = (
+ inference_params.get("system", "\n")
+ + tool_calling_system_prompt
+ ) # add the anthropic tool calling prompt to the system prompt
+ inference_params.pop("tools")
+ data = json.dumps({"messages": messages, **inference_params})
+ else:
+ ## LOAD CONFIG
+ config = litellm.AmazonAnthropicConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+ data = json.dumps({"prompt": prompt, **inference_params})
+ elif provider == "ai21":
+ ## LOAD CONFIG
+ config = litellm.AmazonAI21Config.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+
+ data = json.dumps({"prompt": prompt, **inference_params})
+ elif provider == "mistral":
+ ## LOAD CONFIG
+ config = litellm.AmazonMistralConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+
+ data = json.dumps({"prompt": prompt, **inference_params})
+ elif provider == "amazon": # amazon titan
+ ## LOAD CONFIG
+ config = litellm.AmazonTitanConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+
+ data = json.dumps(
+ {
+ "inputText": prompt,
+ "textGenerationConfig": inference_params,
+ }
+ )
+ elif provider == "meta":
+ ## LOAD CONFIG
+ config = litellm.AmazonLlamaConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in inference_params
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+ inference_params[k] = v
+ data = json.dumps({"prompt": prompt, **inference_params})
else:
- raise Exception("UNSUPPORTED PROVIDER")
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": inference_params,
+ },
+ )
+ raise Exception(
+ "Bedrock HTTPX: Unsupported provider={}, model={}".format(
+ provider, model
+ )
+ )
## COMPLETION CALL
@@ -482,7 +798,7 @@ class BedrockLLM(BaseLLM):
if acompletion:
if isinstance(client, HTTPHandler):
client = None
- if stream:
+ if stream == True and provider != "ai21":
return self.async_streaming(
model=model,
messages=messages,
@@ -511,7 +827,7 @@ class BedrockLLM(BaseLLM):
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
- stream=False,
+ stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
@@ -528,7 +844,7 @@ class BedrockLLM(BaseLLM):
self.client = HTTPHandler(**_params) # type: ignore
else:
self.client = client
- if stream is not None and stream == True:
+ if (stream is not None and stream == True) and provider != "ai21":
response = self.client.post(
url=prepped.url,
headers=prepped.headers, # type: ignore
@@ -541,7 +857,7 @@ class BedrockLLM(BaseLLM):
status_code=response.status_code, message=response.text
)
- decoder = AWSEventStreamDecoder()
+ decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
@@ -550,15 +866,24 @@ class BedrockLLM(BaseLLM):
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=streaming_response,
+ additional_args={"complete_input_dict": data},
+ )
return streaming_response
- response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
-
try:
+ response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
+ except httpx.TimeoutException as e:
+ raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
@@ -591,7 +916,7 @@ class BedrockLLM(BaseLLM):
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
- ) -> ModelResponse:
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
if client is None:
_params = {}
if timeout is not None:
@@ -602,12 +927,20 @@ class BedrockLLM(BaseLLM):
else:
self.client = client # type: ignore
- response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
+ try:
+ response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise BedrockError(status_code=error_code, message=response.text)
+ except httpx.TimeoutException as e:
+ raise BedrockError(status_code=408, message="Timeout error occurred.")
+
return self.process_response(
model=model,
response=response,
model_response=model_response,
- stream=stream,
+ stream=stream if isinstance(stream, bool) else False,
logging_obj=logging_obj,
api_key="",
data=data,
@@ -650,7 +983,7 @@ class BedrockLLM(BaseLLM):
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.text)
- decoder = AWSEventStreamDecoder()
+ decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
@@ -659,6 +992,15 @@ class BedrockLLM(BaseLLM):
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=streaming_response,
+ additional_args={"complete_input_dict": data},
+ )
+
return streaming_response
def embedding(self, *args, **kwargs):
@@ -676,11 +1018,70 @@ def get_response_stream_shape():
class AWSEventStreamDecoder:
- def __init__(self) -> None:
+ def __init__(self, model: str) -> None:
from botocore.parsers import EventStreamJSONParser
+ self.model = model
self.parser = EventStreamJSONParser()
+ def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+ text = ""
+ is_finished = False
+ finish_reason = ""
+ if "outputText" in chunk_data:
+ text = chunk_data["outputText"]
+ # ai21 mapping
+ if "ai21" in self.model: # fake ai21 streaming
+ text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
+ is_finished = True
+ finish_reason = "stop"
+ ######## bedrock.anthropic mappings ###############
+ elif "completion" in chunk_data: # not claude-3
+ text = chunk_data["completion"] # bedrock.anthropic
+ stop_reason = chunk_data.get("stop_reason", None)
+ if stop_reason != None:
+ is_finished = True
+ finish_reason = stop_reason
+ elif "delta" in chunk_data:
+ if chunk_data["delta"].get("text", None) is not None:
+ text = chunk_data["delta"]["text"]
+ stop_reason = chunk_data["delta"].get("stop_reason", None)
+ if stop_reason != None:
+ is_finished = True
+ finish_reason = stop_reason
+ ######## bedrock.mistral mappings ###############
+ elif "outputs" in chunk_data:
+ if (
+ len(chunk_data["outputs"]) == 1
+ and chunk_data["outputs"][0].get("text", None) is not None
+ ):
+ text = chunk_data["outputs"][0]["text"]
+ stop_reason = chunk_data.get("stop_reason", None)
+ if stop_reason != None:
+ is_finished = True
+ finish_reason = stop_reason
+ ######## bedrock.cohere mappings ###############
+ # meta mapping
+ elif "generation" in chunk_data:
+ text = chunk_data["generation"] # bedrock.meta
+ # cohere mapping
+ elif "text" in chunk_data:
+ text = chunk_data["text"] # bedrock.cohere
+ # cohere mapping for finish reason
+ elif "finish_reason" in chunk_data:
+ finish_reason = chunk_data["finish_reason"]
+ is_finished = True
+ elif chunk_data.get("completionReason", None):
+ is_finished = True
+ finish_reason = chunk_data["completionReason"]
+ return GenericStreamingChunk(
+ **{
+ "text": text,
+ "is_finished": is_finished,
+ "finish_reason": finish_reason,
+ }
+ )
+
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
@@ -693,12 +1094,7 @@ class AWSEventStreamDecoder:
if message:
# sse_event = ServerSentEvent(data=message, event="completion")
_data = json.loads(message)
- streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
- text=_data.get("text", ""),
- is_finished=_data.get("is_finished", False),
- finish_reason=_data.get("finish_reason", ""),
- )
- yield streaming_chunk
+ yield self._chunk_parser(chunk_data=_data)
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
@@ -713,12 +1109,7 @@ class AWSEventStreamDecoder:
message = self._parse_message_from_event(event)
if message:
_data = json.loads(message)
- streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
- text=_data.get("text", ""),
- is_finished=_data.get("is_finished", False),
- finish_reason=_data.get("finish_reason", ""),
- )
- yield streaming_chunk
+ yield self._chunk_parser(chunk_data=_data)
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()
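
The bedrock_httpx.py changes above accept an optional model_id (for example a provisioned-throughput ARN), URL-encode it with encode_model_id, and splice it into the invoke endpoint in place of the plain model name. A sketch of that path with a made-up ARN and region:

import urllib.parse


def encode_model_id(model_id: str) -> str:
    # Same call as the new BedrockLLM.encode_model_id: percent-encode everything, including "/" and ":".
    return urllib.parse.quote(model_id, safe="")


region = "us-east-1"
# Made-up provisioned-throughput ARN, purely for illustration.
model_id = encode_model_id(
    "arn:aws:bedrock:us-east-1:123456789012:provisioned-model/abc123"
)
invoke_url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model_id}/invoke"
print(invoke_url)
# .../model/arn%3Aaws%3Abedrock%3Aus-east-1%3A123456789012%3Aprovisioned-model%2Fabc123/invoke
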
diff --git a/litellm/llms/clarifai.py b/litellm/llms/clarifai.py
index e07a8d9e8..4610911e1 100644
--- a/litellm/llms/clarifai.py
+++ b/litellm/llms/clarifai.py
@@ -14,28 +14,25 @@ class ClarifaiError(Exception):
def __init__(self, status_code, message, url):
self.status_code = status_code
self.message = message
- self.request = httpx.Request(
- method="POST", url=url
- )
+ self.request = httpx.Request(method="POST", url=url)
self.response = httpx.Response(status_code=status_code, request=self.request)
- super().__init__(
- self.message
- )
+ super().__init__(self.message)
+
class ClarifaiConfig:
"""
Reference: https://clarifai.com/meta/Llama-2/models/llama2-70b-chat
- TODO fill in the details
"""
+
max_tokens: Optional[int] = None
temperature: Optional[int] = None
top_k: Optional[int] = None
def __init__(
- self,
- max_tokens: Optional[int] = None,
- temperature: Optional[int] = None,
- top_k: Optional[int] = None,
+ self,
+ max_tokens: Optional[int] = None,
+ temperature: Optional[int] = None,
+ top_k: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@@ -60,6 +57,7 @@ class ClarifaiConfig:
and v is not None
}
+
def validate_environment(api_key):
headers = {
"accept": "application/json",
@@ -69,42 +67,37 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}"
return headers
-def completions_to_model(payload):
- # if payload["n"] != 1:
- # raise HTTPException(
- # status_code=422,
- # detail="Only one generation is supported. Please set candidate_count to 1.",
- # )
- params = {}
- if temperature := payload.get("temperature"):
- params["temperature"] = temperature
- if max_tokens := payload.get("max_tokens"):
- params["max_tokens"] = max_tokens
- return {
- "inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
- "model": {"output_info": {"params": params}},
-}
-
+def completions_to_model(payload):
+ # if payload["n"] != 1:
+ # raise HTTPException(
+ # status_code=422,
+ # detail="Only one generation is supported. Please set candidate_count to 1.",
+ # )
+
+ params = {}
+ if temperature := payload.get("temperature"):
+ params["temperature"] = temperature
+ if max_tokens := payload.get("max_tokens"):
+ params["max_tokens"] = max_tokens
+ return {
+ "inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
+ "model": {"output_info": {"params": params}},
+ }
+
+
def process_response(
- model,
- prompt,
- response,
- model_response,
- api_key,
- data,
- encoding,
- logging_obj
- ):
+ model, prompt, response, model_response, api_key, data, encoding, logging_obj
+):
logging_obj.post_call(
- input=prompt,
- api_key=api_key,
- original_response=response.text,
- additional_args={"complete_input_dict": data},
- )
- ## RESPONSE OBJECT
+ input=prompt,
+ api_key=api_key,
+ original_response=response.text,
+ additional_args={"complete_input_dict": data},
+ )
+ ## RESPONSE OBJECT
try:
- completion_response = response.json()
+ completion_response = response.json()
except Exception:
raise ClarifaiError(
message=response.text, status_code=response.status_code, url=model
@@ -119,7 +112,7 @@ def process_response(
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason="stop",
- index=idx + 1, #check
+ index=idx + 1, # check
message=message_obj,
)
choices_list.append(choice_obj)
@@ -143,53 +136,56 @@ def process_response(
)
return model_response
+
def convert_model_to_url(model: str, api_base: str):
user_id, app_id, model_id = model.split(".")
return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"
+
def get_prompt_model_name(url: str):
clarifai_model_name = url.split("/")[-2]
if "claude" in clarifai_model_name:
return "anthropic", clarifai_model_name.replace("_", ".")
- if ("llama" in clarifai_model_name)or ("mistral" in clarifai_model_name):
+ if ("llama" in clarifai_model_name) or ("mistral" in clarifai_model_name):
return "", "meta-llama/llama-2-chat"
else:
return "", clarifai_model_name
+
async def async_completion(
- model: str,
- prompt: str,
- api_base: str,
- custom_prompt_dict: dict,
- model_response: ModelResponse,
- print_verbose: Callable,
- encoding,
- api_key,
- logging_obj,
- data=None,
- optional_params=None,
- litellm_params=None,
- logger_fn=None,
- headers={}):
-
- async_handler = AsyncHTTPHandler(
- timeout=httpx.Timeout(timeout=600.0, connect=5.0)
- )
+ model: str,
+ prompt: str,
+ api_base: str,
+ custom_prompt_dict: dict,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ data=None,
+ optional_params=None,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+):
+
+ async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = await async_handler.post(
- api_base, headers=headers, data=json.dumps(data)
- )
-
- return process_response(
- model=model,
- prompt=prompt,
- response=response,
- model_response=model_response,
- api_key=api_key,
- data=data,
- encoding=encoding,
- logging_obj=logging_obj,
+ api_base, headers=headers, data=json.dumps(data)
)
+ return process_response(
+ model=model,
+ prompt=prompt,
+ response=response,
+ model_response=model_response,
+ api_key=api_key,
+ data=data,
+ encoding=encoding,
+ logging_obj=logging_obj,
+ )
+
+
def completion(
model: str,
messages: list,
@@ -207,14 +203,12 @@ def completion(
):
headers = validate_environment(api_key)
model = convert_model_to_url(model, api_base)
- prompt = " ".join(message["content"] for message in messages) # TODO
+ prompt = " ".join(message["content"] for message in messages) # TODO
## Load Config
config = litellm.ClarifaiConfig.get_config()
for k, v in config.items():
- if (
- k not in optional_params
- ):
+ if k not in optional_params:
optional_params[k] = v
custom_llm_provider, orig_model_name = get_prompt_model_name(model)
@@ -223,14 +217,14 @@ def completion(
model=orig_model_name,
messages=messages,
api_key=api_key,
- custom_llm_provider="clarifai"
+ custom_llm_provider="clarifai",
)
else:
prompt = prompt_factory(
model=orig_model_name,
messages=messages,
api_key=api_key,
- custom_llm_provider=custom_llm_provider
+ custom_llm_provider=custom_llm_provider,
)
# print(prompt); exit(0)
@@ -240,7 +234,6 @@ def completion(
}
data = completions_to_model(data)
-
## LOGGING
logging_obj.pre_call(
input=prompt,
@@ -251,7 +244,7 @@ def completion(
"api_base": api_base,
},
)
- if acompletion==True:
+ if acompletion == True:
return async_completion(
model=model,
prompt=prompt,
@@ -271,15 +264,17 @@ def completion(
else:
## COMPLETION CALL
response = requests.post(
- model,
- headers=headers,
- data=json.dumps(data),
- )
+ model,
+ headers=headers,
+ data=json.dumps(data),
+ )
# print(response.content); exit()
if response.status_code != 200:
- raise ClarifaiError(status_code=response.status_code, message=response.text, url=model)
-
+ raise ClarifaiError(
+ status_code=response.status_code, message=response.text, url=model
+ )
+
if "stream" in optional_params and optional_params["stream"] == True:
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
@@ -287,11 +282,11 @@ def completion(
model=model,
custom_llm_provider="clarifai",
logging_obj=logging_obj,
- )
+ )
return stream_response
-
+
else:
- return process_response(
+ return process_response(
model=model,
prompt=prompt,
response=response,
@@ -299,8 +294,9 @@ def completion(
api_key=api_key,
data=data,
encoding=encoding,
- logging_obj=logging_obj)
-
+ logging_obj=logging_obj,
+ )
+
class ModelResponseIterator:
def __init__(self, model_response):
@@ -325,4 +321,4 @@ class ModelResponseIterator:
if self.is_done:
raise StopAsyncIteration
self.is_done = True
- return self.model_response
\ No newline at end of file
+ return self.model_response
diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py
index 0ebdf38f1..14a66b54a 100644
--- a/litellm/llms/cohere.py
+++ b/litellm/llms/cohere.py
@@ -117,6 +117,7 @@ class CohereConfig:
def validate_environment(api_key):
headers = {
+ "Request-Source":"unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
diff --git a/litellm/llms/cohere_chat.py b/litellm/llms/cohere_chat.py
index e4de6ddcb..8ae839243 100644
--- a/litellm/llms/cohere_chat.py
+++ b/litellm/llms/cohere_chat.py
@@ -112,6 +112,7 @@ class CohereChatConfig:
def validate_environment(api_key):
headers = {
+ "Request-Source":"unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 0adbd95bf..8b5f11398 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -7,8 +7,12 @@ _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
class AsyncHTTPHandler:
def __init__(
- self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
+ self,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ concurrent_limit=1000,
):
+ if timeout is None:
+ timeout = _DEFAULT_TIMEOUT
# Create a client with a connection pool
self.client = httpx.AsyncClient(
timeout=timeout,
@@ -39,12 +43,13 @@ class AsyncHTTPHandler:
self,
url: str,
data: Optional[Union[dict, str]] = None, # type: ignore
+ json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
- "POST", url, data=data, params=params, headers=headers # type: ignore
+ "POST", url, data=data, json=json, params=params, headers=headers # type: ignore
)
response = await self.client.send(req, stream=stream)
return response
@@ -59,7 +64,7 @@ class AsyncHTTPHandler:
class HTTPHandler:
def __init__(
self,
- timeout: Optional[httpx.Timeout] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
concurrent_limit=1000,
client: Optional[httpx.Client] = None,
):
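
The http_handler.py changes above let AsyncHTTPHandler accept a plain float timeout (falling back to the module default when None) and add a json= parameter to post() so httpx serializes the body and sets the Content-Type header. A usage sketch, assuming the litellm package is importable and a proxy is listening at a placeholder URL:

import asyncio

from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler


async def main() -> None:
    handler = AsyncHTTPHandler(timeout=600.0)  # float accepted, not only httpx.Timeout
    response = await handler.post(
        "http://localhost:4000/chat/completions",  # placeholder endpoint
        headers={"Authorization": "Bearer sk-1234"},  # placeholder key
        json={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hi"}],
        },
    )
    print(response.status_code)


# asyncio.run(main())  # requires something listening at the placeholder URL
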
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
new file mode 100644
index 000000000..7b2013710
--- /dev/null
+++ b/litellm/llms/databricks.py
@@ -0,0 +1,696 @@
+# What is this?
+## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+import os, types
+import json
+from enum import Enum
+import requests, copy # type: ignore
+import time
+from typing import Callable, Optional, List, Union, Tuple, Literal
+from litellm.utils import (
+ ModelResponse,
+ Usage,
+ map_finish_reason,
+ CustomStreamWrapper,
+ EmbeddingResponse,
+)
+import litellm
+from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from .base import BaseLLM
+import httpx # type: ignore
+from litellm.types.llms.databricks import GenericStreamingChunk
+from litellm.types.utils import ProviderField
+
+
+class DatabricksError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
+
+
+class DatabricksConfig:
+ """
+ Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+ """
+
+ max_tokens: Optional[int] = None
+ temperature: Optional[int] = None
+ top_p: Optional[int] = None
+ top_k: Optional[int] = None
+ stop: Optional[Union[List[str], str]] = None
+ n: Optional[int] = None
+
+ def __init__(
+ self,
+ max_tokens: Optional[int] = None,
+ temperature: Optional[int] = None,
+ top_p: Optional[int] = None,
+ top_k: Optional[int] = None,
+ stop: Optional[Union[List[str], str]] = None,
+ n: Optional[int] = None,
+ ) -> None:
+ locals_ = locals()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_required_params(self) -> List[ProviderField]:
+ """For a given provider, return it's required fields with a description"""
+ return [
+ ProviderField(
+ field_name="api_key",
+ field_type="string",
+ field_description="Your Databricks API Key.",
+ field_value="dapi...",
+ ),
+ ProviderField(
+ field_name="api_base",
+ field_type="string",
+ field_description="Your Databricks API Base.",
+ field_value="https://adb-..",
+ ),
+ ]
+
+ def get_supported_openai_params(self):
+ return ["stream", "stop", "temperature", "top_p", "max_tokens", "n"]
+
+ def map_openai_params(self, non_default_params: dict, optional_params: dict):
+ for param, value in non_default_params.items():
+ if param == "max_tokens":
+ optional_params["max_tokens"] = value
+ if param == "n":
+ optional_params["n"] = value
+ if param == "stream" and value == True:
+ optional_params["stream"] = value
+ if param == "temperature":
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["top_p"] = value
+ if param == "stop":
+ optional_params["stop"] = value
+ return optional_params
+
+ def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
+ try:
+ text = ""
+ is_finished = False
+ finish_reason = None
+ logprobs = None
+ usage = None
+ original_chunk = None # this is used for function/tool calling
+ chunk_data = chunk_data.replace("data:", "")
+ chunk_data = chunk_data.strip()
+ if len(chunk_data) == 0:
+ return {
+ "text": "",
+ "is_finished": is_finished,
+ "finish_reason": finish_reason,
+ }
+ chunk_data_dict = json.loads(chunk_data)
+ str_line = litellm.ModelResponse(**chunk_data_dict, stream=True)
+
+ if len(str_line.choices) > 0:
+ if (
+ str_line.choices[0].delta is not None # type: ignore
+ and str_line.choices[0].delta.content is not None # type: ignore
+ ):
+ text = str_line.choices[0].delta.content # type: ignore
+ else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
+ original_chunk = str_line
+ if str_line.choices[0].finish_reason:
+ is_finished = True
+ finish_reason = str_line.choices[0].finish_reason
+ if finish_reason == "content_filter":
+ if hasattr(str_line.choices[0], "content_filter_result"):
+ error_message = json.dumps(
+ str_line.choices[0].content_filter_result # type: ignore
+ )
+ else:
+ error_message = "Azure Response={}".format(
+ str(dict(str_line))
+ )
+ raise litellm.AzureOpenAIError(
+ status_code=400, message=error_message
+ )
+
+ # checking for logprobs
+ if (
+ hasattr(str_line.choices[0], "logprobs")
+ and str_line.choices[0].logprobs is not None
+ ):
+ logprobs = str_line.choices[0].logprobs
+ else:
+ logprobs = None
+
+ usage = getattr(str_line, "usage", None)
+
+ return GenericStreamingChunk(
+ text=text,
+ is_finished=is_finished,
+ finish_reason=finish_reason,
+ logprobs=logprobs,
+ original_chunk=original_chunk,
+ usage=usage,
+ )
+ except Exception as e:
+ raise e
+
+
+class DatabricksEmbeddingConfig:
+ """
+ Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
+ """
+
+ instruction: Optional[str] = (
+ None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
+ )
+
+ def __init__(self, instruction: Optional[str] = None) -> None:
+ locals_ = locals()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(
+ self,
+ ): # no optional openai embedding params supported
+ return []
+
+ def map_openai_params(self, non_default_params: dict, optional_params: dict):
+ return optional_params
+
+
+class DatabricksChatCompletion(BaseLLM):
+ def __init__(self) -> None:
+ super().__init__()
+
+ # makes headers for API call
+
+ def _validate_environment(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ endpoint_type: Literal["chat_completions", "embeddings"],
+ ) -> Tuple[str, dict]:
+ if api_key is None:
+ raise DatabricksError(
+ status_code=400,
+ message="Missing Databricks API Key - A call is being made to Databricks but no key is set either in the environment variables (DATABRICKS_API_KEY) or via params",
+ )
+
+ if api_base is None:
+ raise DatabricksError(
+ status_code=400,
+ message="Missing Databricks API Base - A call is being made to Databricks but no api base is set either in the environment variables (DATABRICKS_API_BASE) or via params",
+ )
+
+ headers = {
+ "Authorization": "Bearer {}".format(api_key),
+ "Content-Type": "application/json",
+ }
+
+ if endpoint_type == "chat_completions":
+ api_base = "{}/chat/completions".format(api_base)
+ elif endpoint_type == "embeddings":
+ api_base = "{}/embeddings".format(api_base)
+ return api_base, headers
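+
+    # Illustrative result of _validate_environment above (hypothetical key and workspace URL):
+    #   _validate_environment(
+    #       api_key="dapi-xxx",
+    #       api_base="https://my-workspace.cloud.databricks.com/serving-endpoints",
+    #       endpoint_type="chat_completions",
+    #   )
+    #   -> ("https://my-workspace.cloud.databricks.com/serving-endpoints/chat/completions",
+    #       {"Authorization": "Bearer dapi-xxx", "Content-Type": "application/json"})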
+
+ def process_response(
+ self,
+ model: str,
+ response: Union[requests.Response, httpx.Response],
+ model_response: ModelResponse,
+ stream: bool,
+ logging_obj: litellm.utils.Logging,
+ optional_params: dict,
+ api_key: str,
+ data: Union[dict, str],
+ messages: List,
+ print_verbose,
+ encoding,
+ ) -> ModelResponse:
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=response.text,
+ additional_args={"complete_input_dict": data},
+ )
+ print_verbose(f"raw model_response: {response.text}")
+ ## RESPONSE OBJECT
+ try:
+ completion_response = response.json()
+        except Exception:
+ raise DatabricksError(
+ message=response.text, status_code=response.status_code
+ )
+ if "error" in completion_response:
+ raise DatabricksError(
+ message=str(completion_response["error"]),
+ status_code=response.status_code,
+ )
+ else:
+ text_content = ""
+ tool_calls = []
+ for content in completion_response["content"]:
+ if content["type"] == "text":
+ text_content += content["text"]
+ ## TOOL CALLING
+ elif content["type"] == "tool_use":
+ tool_calls.append(
+ {
+ "id": content["id"],
+ "type": "function",
+ "function": {
+ "name": content["name"],
+ "arguments": json.dumps(content["input"]),
+ },
+ }
+ )
+
+ _message = litellm.Message(
+ tool_calls=tool_calls,
+ content=text_content or None,
+ )
+ model_response.choices[0].message = _message # type: ignore
+ model_response._hidden_params["original_response"] = completion_response[
+ "content"
+            ]  # allow user to access the raw tool calling response
+
+ model_response.choices[0].finish_reason = map_finish_reason(
+ completion_response["stop_reason"]
+ )
+
+ ## CALCULATING USAGE
+ prompt_tokens = completion_response["usage"]["input_tokens"]
+ completion_tokens = completion_response["usage"]["output_tokens"]
+ total_tokens = prompt_tokens + completion_tokens
+
+ model_response["created"] = int(time.time())
+ model_response["model"] = model
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ )
+ setattr(model_response, "usage", usage) # type: ignore
+ return model_response
+
+ async def acompletion_stream_function(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_prompt_dict: dict,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ stream,
+ data: dict,
+ optional_params=None,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+ ):
+ self.async_handler = AsyncHTTPHandler(
+ timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+ )
+ data["stream"] = True
+ try:
+ response = await self.async_handler.post(
+ api_base, headers=headers, data=json.dumps(data), stream=True
+ )
+ response.raise_for_status()
+
+ completion_stream = response.aiter_lines()
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code, message=response.text
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(status_code=408, message="Timeout error occurred.")
+ except Exception as e:
+ raise DatabricksError(status_code=500, message=str(e))
+
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider="databricks",
+ logging_obj=logging_obj,
+ )
+ return streamwrapper
+
+ async def acompletion_function(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_prompt_dict: dict,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ stream,
+ data: dict,
+ optional_params: dict,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ ) -> ModelResponse:
+ if timeout is None:
+ timeout = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ self.async_handler = AsyncHTTPHandler(timeout=timeout)
+
+ try:
+ response = await self.async_handler.post(
+ api_base, headers=headers, data=json.dumps(data)
+ )
+ response.raise_for_status()
+
+ response_json = response.json()
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code,
+ message=response.text if response else str(e),
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(status_code=408, message="Timeout error occurred.")
+ except Exception as e:
+ raise DatabricksError(status_code=500, message=str(e))
+
+ return ModelResponse(**response_json)
+
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_prompt_dict: dict,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ optional_params: dict,
+ acompletion=None,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ ):
+ api_base, headers = self._validate_environment(
+ api_base=api_base, api_key=api_key, endpoint_type="chat_completions"
+ )
+ ## Load Config
+ config = litellm.DatabricksConfig().get_config()
+ for k, v in config.items():
+ if (
+ k not in optional_params
+            ):  # completion(top_k=3) > databricks_config(top_k=3) <- allows for dynamic variables to be passed in
+ optional_params[k] = v
+
+ stream = optional_params.pop("stream", None)
+
+ data = {
+ "model": model,
+ "messages": messages,
+ **optional_params,
+ }
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+ if acompletion == True:
+ if (
+ stream is not None and stream == True
+        ):  # if stream=True - return an async streaming response
+            print_verbose("makes async databricks streaming POST request")
+ data["stream"] = stream
+ return self.acompletion_stream_function(
+ model=model,
+ messages=messages,
+ data=data,
+ api_base=api_base,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ api_key=api_key,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=stream,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ headers=headers,
+ )
+ else:
+ return self.acompletion_function(
+ model=model,
+ messages=messages,
+ data=data,
+ api_base=api_base,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ api_key=api_key,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=stream,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ headers=headers,
+ timeout=timeout,
+ )
+ else:
+ if client is None or isinstance(client, AsyncHTTPHandler):
+ self.client = HTTPHandler(timeout=timeout) # type: ignore
+ else:
+ self.client = client
+ ## COMPLETION CALL
+ if (
+ stream is not None and stream == True
+            ):  # if stream=True - return a streaming response
+ print_verbose("makes dbrx streaming POST request")
+ data["stream"] = stream
+ try:
+ response = self.client.post(
+ api_base, headers=headers, data=json.dumps(data), stream=stream
+ )
+ response.raise_for_status()
+ completion_stream = response.iter_lines()
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code, message=response.text
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(
+ status_code=408, message="Timeout error occurred."
+ )
+ except Exception as e:
+                    raise DatabricksError(status_code=500, message=str(e))
+
+ streaming_response = CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider="databricks",
+ logging_obj=logging_obj,
+ )
+ return streaming_response
+
+ else:
+ try:
+ response = self.client.post(
+ api_base, headers=headers, data=json.dumps(data)
+ )
+ response.raise_for_status()
+
+ response_json = response.json()
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code, message=response.text
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(
+ status_code=408, message="Timeout error occurred."
+ )
+ except Exception as e:
+ raise DatabricksError(status_code=500, message=str(e))
+
+ return ModelResponse(**response_json)
+
+ async def aembedding(
+ self,
+ input: list,
+ data: dict,
+ model_response: ModelResponse,
+ timeout: float,
+ api_key: str,
+ api_base: str,
+ logging_obj,
+ headers: dict,
+ client=None,
+ ) -> EmbeddingResponse:
+ response = None
+ try:
+ if client is None or isinstance(client, AsyncHTTPHandler):
+ self.async_client = AsyncHTTPHandler(timeout=timeout) # type: ignore
+ else:
+ self.async_client = client
+
+ try:
+ response = await self.async_client.post(
+ api_base,
+ headers=headers,
+ data=json.dumps(data),
+ ) # type: ignore
+
+ response.raise_for_status()
+
+ response_json = response.json()
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code,
+ message=response.text if response else str(e),
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(
+ status_code=408, message="Timeout error occurred."
+ )
+ except Exception as e:
+ raise DatabricksError(status_code=500, message=str(e))
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=response_json,
+ )
+ return EmbeddingResponse(**response_json)
+ except Exception as e:
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ original_response=str(e),
+ )
+ raise e
+
+ def embedding(
+ self,
+ model: str,
+ input: list,
+ timeout: float,
+ logging_obj,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ optional_params: dict,
+ model_response: Optional[litellm.utils.EmbeddingResponse] = None,
+ client=None,
+ aembedding=None,
+ ) -> EmbeddingResponse:
+ api_base, headers = self._validate_environment(
+ api_base=api_base, api_key=api_key, endpoint_type="embeddings"
+ )
+ model = model
+ data = {"model": model, "input": input, **optional_params}
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data, "api_base": api_base},
+ )
+
+ if aembedding == True:
+ return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers) # type: ignore
+ if client is None or isinstance(client, AsyncHTTPHandler):
+ self.client = HTTPHandler(timeout=timeout) # type: ignore
+ else:
+ self.client = client
+
+ ## EMBEDDING CALL
+ try:
+ response = self.client.post(
+ api_base,
+ headers=headers,
+ data=json.dumps(data),
+ ) # type: ignore
+
+ response.raise_for_status() # type: ignore
+
+ response_json = response.json() # type: ignore
+ except httpx.HTTPStatusError as e:
+ raise DatabricksError(
+ status_code=e.response.status_code,
+ message=response.text if response else str(e),
+ )
+ except httpx.TimeoutException as e:
+ raise DatabricksError(status_code=408, message="Timeout error occurred.")
+ except Exception as e:
+ raise DatabricksError(status_code=500, message=str(e))
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=response_json,
+ )
+
+ return litellm.EmbeddingResponse(**response_json)
diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py
index 60220fd29..a55b39aef 100644
--- a/litellm/llms/gemini.py
+++ b/litellm/llms/gemini.py
@@ -260,7 +260,7 @@ def completion(
message_obj = Message(content=item.content.parts[0].text)
else:
message_obj = Message(content=None)
- choice_obj = Choices(index=idx + 1, message=message_obj)
+ choice_obj = Choices(index=idx, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
@@ -352,7 +352,7 @@ async def async_completion(
message_obj = Message(content=item.content.parts[0].text)
else:
message_obj = Message(content=None)
- choice_obj = Choices(index=idx + 1, message=message_obj)
+ choice_obj = Choices(index=idx, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index 9c9b5e898..283878056 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -45,6 +45,8 @@ class OllamaConfig:
- `temperature` (float): The temperature of the model. Increasing the temperature will make the model answer more creatively. Default: 0.8. Example usage: temperature 0.7
+ - `seed` (int): Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. Example usage: seed 42
+
- `stop` (string[]): Sets the stop sequences to use. Example usage: stop "AI assistant:"
- `tfs_z` (float): Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. Default: 1. Example usage: tfs_z 1
@@ -69,6 +71,7 @@ class OllamaConfig:
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
+ seed: Optional[int] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
@@ -90,6 +93,7 @@ class OllamaConfig:
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
+ seed: Optional[int] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
@@ -120,6 +124,44 @@ class OllamaConfig:
)
and v is not None
}
+ def get_supported_openai_params(
+ self,
+ ):
+ return [
+ "max_tokens",
+ "stream",
+ "top_p",
+ "temperature",
+ "seed",
+ "frequency_penalty",
+ "stop",
+ "response_format",
+ ]
+
+# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
+# and convert to jpeg if necessary.
+def _convert_image(image):
+ import base64, io
+ try:
+ from PIL import Image
+    except ImportError:
+        raise Exception(
+            "ollama image conversion failed. Please run `pip install Pillow`"
+        )
+
+ orig = image
+ if image.startswith("data:"):
+ image = image.split(",")[-1]
+ try:
+ image_data = Image.open(io.BytesIO(base64.b64decode(image)))
+ if image_data.format in ["JPEG", "PNG"]:
+ return image
+    except Exception:
+ return orig
+ jpeg_image = io.BytesIO()
+ image_data.convert("RGB").save(jpeg_image, "JPEG")
+ jpeg_image.seek(0)
+ return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")
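+
+# Illustrative behaviour of _convert_image above (hypothetical inputs): a data-URI such as
+# "data:image/png;base64,iVBOR..." has its prefix stripped and the raw base64 payload returned
+# unchanged, while base64 data in any other Pillow-readable format is re-encoded as base64 JPEG.
+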
# ollama implementation
@@ -158,7 +200,7 @@ def get_ollama_response(
if format is not None:
data["format"] = format
if images is not None:
- data["images"] = images
+ data["images"] = [_convert_image(image) for image in images]
## LOGGING
logging_obj.pre_call(
diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index d1ff4953f..a05807722 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -45,6 +45,8 @@ class OllamaChatConfig:
- `temperature` (float): The temperature of the model. Increasing the temperature will make the model answer more creatively. Default: 0.8. Example usage: temperature 0.7
+ - `seed` (int): Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. Example usage: seed 42
+
- `stop` (string[]): Sets the stop sequences to use. Example usage: stop "AI assistant:"
- `tfs_z` (float): Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. Default: 1. Example usage: tfs_z 1
@@ -69,6 +71,7 @@ class OllamaChatConfig:
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
+ seed: Optional[int] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
@@ -90,6 +93,7 @@ class OllamaChatConfig:
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
+ seed: Optional[int] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
@@ -130,6 +134,7 @@ class OllamaChatConfig:
"stream",
"top_p",
"temperature",
+ "seed",
"frequency_penalty",
"stop",
"tools",
@@ -146,6 +151,8 @@ class OllamaChatConfig:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
+ if param == "seed":
+ optional_params["seed"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "frequency_penalty":
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 7acbdfae0..e68a50347 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -21,11 +21,12 @@ from litellm.utils import (
TranscriptionResponse,
TextCompletionResponse,
)
-from typing import Callable, Optional
+from typing import Callable, Optional, Coroutine
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
from ..types.llms.openai import *
+import openai
class OpenAIError(Exception):
@@ -96,7 +97,7 @@ class MistralConfig:
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
) -> None:
- locals_ = locals()
+ locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@@ -157,6 +158,102 @@ class MistralConfig:
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
+ if param == "response_format":
+ optional_params["response_format"] = value
+ return optional_params
+
+
+class DeepInfraConfig:
+ """
+ Reference: https://deepinfra.com/docs/advanced/openai_api
+
+ The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
+ """
+
+ frequency_penalty: Optional[int] = None
+ function_call: Optional[Union[str, dict]] = None
+ functions: Optional[list] = None
+ logit_bias: Optional[dict] = None
+ max_tokens: Optional[int] = None
+ n: Optional[int] = None
+ presence_penalty: Optional[int] = None
+ stop: Optional[Union[str, list]] = None
+ temperature: Optional[int] = None
+ top_p: Optional[int] = None
+ response_format: Optional[dict] = None
+ tools: Optional[list] = None
+ tool_choice: Optional[Union[str, dict]] = None
+
+ def __init__(
+ self,
+ frequency_penalty: Optional[int] = None,
+ function_call: Optional[Union[str, dict]] = None,
+ functions: Optional[list] = None,
+ logit_bias: Optional[dict] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[int] = None,
+ stop: Optional[Union[str, list]] = None,
+ temperature: Optional[int] = None,
+ top_p: Optional[int] = None,
+ response_format: Optional[dict] = None,
+ tools: Optional[list] = None,
+ tool_choice: Optional[Union[str, dict]] = None,
+ ) -> None:
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(self):
+ return [
+ "stream",
+ "frequency_penalty",
+ "function_call",
+ "functions",
+ "logit_bias",
+ "max_tokens",
+ "n",
+ "presence_penalty",
+ "stop",
+ "temperature",
+ "top_p",
+ "response_format",
+ "tools",
+ "tool_choice",
+ ]
+
+ def map_openai_params(
+ self, non_default_params: dict, optional_params: dict, model: str
+ ):
+ supported_openai_params = self.get_supported_openai_params()
+ for param, value in non_default_params.items():
+ if (
+ param == "temperature"
+ and value == 0
+ and model == "mistralai/Mistral-7B-Instruct-v0.1"
+        ):  # this model does not support temperature == 0
+ value = 0.0001 # close to 0
+ if param in supported_openai_params:
+ optional_params[param] = value
return optional_params
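+
+# Illustrative mapping by DeepInfraConfig above (hypothetical params):
+#   DeepInfraConfig().map_openai_params(
+#       non_default_params={"temperature": 0, "max_tokens": 100},
+#       optional_params={},
+#       model="mistralai/Mistral-7B-Instruct-v0.1",
+#   )
+# returns {"temperature": 0.0001, "max_tokens": 100}, since that model rejects temperature == 0.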
@@ -197,6 +294,7 @@ class OpenAIConfig:
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
+ response_format: Optional[dict] = None
def __init__(
self,
@@ -210,8 +308,9 @@ class OpenAIConfig:
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
+ response_format: Optional[dict] = None,
) -> None:
- locals_ = locals()
+ locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@@ -234,6 +333,52 @@ class OpenAIConfig:
and v is not None
}
+ def get_supported_openai_params(self, model: str) -> list:
+ base_params = [
+ "frequency_penalty",
+ "logit_bias",
+ "logprobs",
+ "top_logprobs",
+ "max_tokens",
+ "n",
+ "presence_penalty",
+ "seed",
+ "stop",
+ "stream",
+ "stream_options",
+ "temperature",
+ "top_p",
+ "tools",
+ "tool_choice",
+ "function_call",
+ "functions",
+ "max_retries",
+ "extra_headers",
+ ] # works across all models
+
+ model_specific_params = []
+ if (
+ model != "gpt-3.5-turbo-16k" and model != "gpt-4"
+        ):  # gpt-3.5-turbo-16k and gpt-4 do not support 'response_format'
+ model_specific_params.append("response_format")
+
+ if (
+ model in litellm.open_ai_chat_completion_models
+ ) or model in litellm.open_ai_text_completion_models:
+ model_specific_params.append(
+ "user"
+ ) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
+ return base_params + model_specific_params
+
+ def map_openai_params(
+ self, non_default_params: dict, optional_params: dict, model: str
+ ) -> dict:
+ supported_openai_params = self.get_supported_openai_params(model)
+ for param, value in non_default_params.items():
+ if param in supported_openai_params:
+ optional_params[param] = value
+ return optional_params
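+
+    # Illustrative use of map_openai_params above (hypothetical params): for model="gpt-4",
+    #   map_openai_params({"temperature": 0.2, "response_format": {"type": "json_object"}, "user": "u1"}, {}, "gpt-4")
+    # copies temperature and user but drops response_format, since get_supported_openai_params
+    # excludes it for "gpt-4"; params outside the supported list are dropped as well.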
+
class OpenAITextCompletionConfig:
"""
@@ -294,7 +439,7 @@ class OpenAITextCompletionConfig:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> None:
- locals_ = locals()
+ locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@@ -363,6 +508,7 @@ class OpenAIChatCompletion(BaseLLM):
self,
model_response: ModelResponse,
timeout: Union[float, httpx.Timeout],
+ optional_params: dict,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable] = None,
@@ -370,7 +516,6 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str] = None,
acompletion: bool = False,
logging_obj=None,
- optional_params=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
@@ -754,10 +899,10 @@ class OpenAIChatCompletion(BaseLLM):
model: str,
input: list,
timeout: float,
+ logging_obj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
- logging_obj=None,
optional_params=None,
client=None,
aembedding=None,
@@ -946,8 +1091,8 @@ class OpenAIChatCompletion(BaseLLM):
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
- api_key: Optional[str] = None,
- api_base: Optional[str] = None,
+ api_key: Optional[str],
+ api_base: Optional[str],
client=None,
logging_obj=None,
atranscription: bool = False,
@@ -1003,7 +1148,6 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None,
logging_obj=None,
):
- response = None
try:
if client is None:
openai_aclient = AsyncOpenAI(
@@ -1037,6 +1181,95 @@ class OpenAIChatCompletion(BaseLLM):
)
raise e
+ def audio_speech(
+ self,
+ model: str,
+ input: str,
+ voice: str,
+ optional_params: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ organization: Optional[str],
+ project: Optional[str],
+ max_retries: int,
+ timeout: Union[float, httpx.Timeout],
+ aspeech: Optional[bool] = None,
+ client=None,
+ ) -> HttpxBinaryResponseContent:
+
+ if aspeech is not None and aspeech == True:
+ return self.async_audio_speech(
+ model=model,
+ input=input,
+ voice=voice,
+ optional_params=optional_params,
+ api_key=api_key,
+ api_base=api_base,
+ organization=organization,
+ project=project,
+ max_retries=max_retries,
+ timeout=timeout,
+ client=client,
+ ) # type: ignore
+
+ if client is None:
+ openai_client = OpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ organization=organization,
+ project=project,
+ http_client=litellm.client_session,
+ timeout=timeout,
+ max_retries=max_retries,
+ )
+ else:
+ openai_client = client
+
+ response = openai_client.audio.speech.create(
+ model=model,
+ voice=voice, # type: ignore
+ input=input,
+ **optional_params,
+ )
+ return response
+
+ async def async_audio_speech(
+ self,
+ model: str,
+ input: str,
+ voice: str,
+ optional_params: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ organization: Optional[str],
+ project: Optional[str],
+ max_retries: int,
+ timeout: Union[float, httpx.Timeout],
+ client=None,
+ ) -> HttpxBinaryResponseContent:
+
+ if client is None:
+ openai_client = AsyncOpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ organization=organization,
+ project=project,
+ http_client=litellm.aclient_session,
+ timeout=timeout,
+ max_retries=max_retries,
+ )
+ else:
+ openai_client = client
+
+ response = await openai_client.audio.speech.create(
+ model=model,
+ voice=voice, # type: ignore
+ input=input,
+ **optional_params,
+ )
+
+ return response
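+
+    # Illustrative call of audio_speech above (hypothetical values): with model="tts-1",
+    # voice="alloy", input="Hello world", and a valid OpenAI key, it returns the SDK's
+    # HttpxBinaryResponseContent wrapping the generated audio bytes; optional_params may carry
+    # extra kwargs such as response_format="mp3". The async variant does the same via AsyncOpenAI.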
+
async def ahealth_check(
self,
model: Optional[str],
@@ -1358,6 +1591,322 @@ class OpenAITextCompletion(BaseLLM):
yield transformed_chunk
+class OpenAIFilesAPI(BaseLLM):
+ """
+    OpenAI methods for the Files API
+ - create_file()
+ - retrieve_file()
+ - list_files()
+ - delete_file()
+ - file_content()
+ - update_file()
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def get_openai_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ _is_async: bool = False,
+ ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
+ received_args = locals()
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client" or k == "_is_async":
+ pass
+ elif k == "api_base" and v is not None:
+ data["base_url"] = v
+ elif v is not None:
+ data[k] = v
+ if _is_async is True:
+ openai_client = AsyncOpenAI(**data)
+ else:
+ openai_client = OpenAI(**data) # type: ignore
+ else:
+ openai_client = client
+
+ return openai_client
+
+ async def acreate_file(
+ self,
+ create_file_data: CreateFileRequest,
+ openai_client: AsyncOpenAI,
+ ) -> FileObject:
+ response = await openai_client.files.create(**create_file_data)
+ return response
+
+ def create_file(
+ self,
+ _is_async: bool,
+ create_file_data: CreateFileRequest,
+ api_base: str,
+ api_key: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.acreate_file( # type: ignore
+ create_file_data=create_file_data, openai_client=openai_client
+ )
+ response = openai_client.files.create(**create_file_data)
+ return response
+
+ async def afile_content(
+ self,
+ file_content_request: FileContentRequest,
+ openai_client: AsyncOpenAI,
+ ) -> HttpxBinaryResponseContent:
+ response = await openai_client.files.content(**file_content_request)
+ return response
+
+ def file_content(
+ self,
+ _is_async: bool,
+ file_content_request: FileContentRequest,
+ api_base: str,
+ api_key: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ) -> Union[
+ HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
+ ]:
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.afile_content( # type: ignore
+ file_content_request=file_content_request,
+ openai_client=openai_client,
+ )
+ response = openai_client.files.content(**file_content_request)
+
+ return response
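+
+    # Illustrative sketch of create_file above (hypothetical file), assuming CreateFileRequest
+    # mirrors openai.files.create kwargs:
+    #   OpenAIFilesAPI().create_file(
+    #       _is_async=False,
+    #       create_file_data={"file": open("batch_input.jsonl", "rb"), "purpose": "batch"},
+    #       api_base="https://api.openai.com/v1", api_key=os.getenv("OPENAI_API_KEY"),
+    #       timeout=600.0, max_retries=2, organization=None,
+    #   )  # -> FileObject; with _is_async=True and an AsyncOpenAI client a coroutine is returned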
+
+
+class OpenAIBatchesAPI(BaseLLM):
+ """
+    OpenAI methods for the Batches API
+ - create_batch()
+ - retrieve_batch()
+ - cancel_batch()
+ - list_batch()
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def get_openai_client(
+ self,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ _is_async: bool = False,
+ ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
+ received_args = locals()
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
+ if client is None:
+ data = {}
+ for k, v in received_args.items():
+ if k == "self" or k == "client" or k == "_is_async":
+ pass
+ elif k == "api_base" and v is not None:
+ data["base_url"] = v
+ elif v is not None:
+ data[k] = v
+ if _is_async is True:
+ openai_client = AsyncOpenAI(**data)
+ else:
+ openai_client = OpenAI(**data) # type: ignore
+ else:
+ openai_client = client
+
+ return openai_client
+
+ async def acreate_batch(
+ self,
+ create_batch_data: CreateBatchRequest,
+ openai_client: AsyncOpenAI,
+ ) -> Batch:
+ response = await openai_client.batches.create(**create_batch_data)
+ return response
+
+ def create_batch(
+ self,
+ _is_async: bool,
+ create_batch_data: CreateBatchRequest,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+ ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.acreate_batch( # type: ignore
+ create_batch_data=create_batch_data, openai_client=openai_client
+ )
+ response = openai_client.batches.create(**create_batch_data)
+ return response
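+
+    # Illustrative sketch of create_batch above (hypothetical ids), assuming CreateBatchRequest
+    # mirrors openai.batches.create kwargs:
+    #   create_batch_data = {"input_file_id": "file-abc123", "endpoint": "/v1/chat/completions",
+    #                        "completion_window": "24h"}
+    # A sync call returns a Batch object; with _is_async=True the coroutine from acreate_batch is returned.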
+
+ async def aretrieve_batch(
+ self,
+ retrieve_batch_data: RetrieveBatchRequest,
+ openai_client: AsyncOpenAI,
+ ) -> Batch:
+ response = await openai_client.batches.retrieve(**retrieve_batch_data)
+ return response
+
+ def retrieve_batch(
+ self,
+ _is_async: bool,
+ retrieve_batch_data: RetrieveBatchRequest,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI] = None,
+ ):
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncOpenAI):
+ raise ValueError(
+ "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ )
+ return self.aretrieve_batch( # type: ignore
+ retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
+ )
+ response = openai_client.batches.retrieve(**retrieve_batch_data)
+ return response
+
+ def cancel_batch(
+ self,
+ _is_async: bool,
+ cancel_batch_data: CancelBatchRequest,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str],
+ client: Optional[OpenAI] = None,
+ ):
+ openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+ api_key=api_key,
+ api_base=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ client=client,
+ _is_async=_is_async,
+ )
+ if openai_client is None:
+ raise ValueError(
+ "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+ )
+ response = openai_client.batches.cancel(**cancel_batch_data)
+ return response
+
+ # def list_batch(
+ # self,
+ # list_batch_data: ListBatchRequest,
+ # api_key: Optional[str],
+ # api_base: Optional[str],
+ # timeout: Union[float, httpx.Timeout],
+ # max_retries: Optional[int],
+ # organization: Optional[str],
+ # client: Optional[OpenAI] = None,
+ # ):
+ # openai_client: OpenAI = self.get_openai_client(
+ # api_key=api_key,
+ # api_base=api_base,
+ # timeout=timeout,
+ # max_retries=max_retries,
+ # organization=organization,
+ # client=client,
+ # )
+ # response = openai_client.batches.list(**list_batch_data)
+ # return response
+
+
class OpenAIAssistantsAPI(BaseLLM):
def __init__(self) -> None:
super().__init__()
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index cf593369c..41ecb486c 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -12,6 +12,7 @@ from typing import (
Sequence,
)
import litellm
+import litellm.types
from litellm.types.completion import (
ChatCompletionUserMessageParam,
ChatCompletionSystemMessageParam,
@@ -20,9 +21,12 @@ from litellm.types.completion import (
ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam,
)
+import litellm.types.llms
from litellm.types.llms.anthropic import *
import uuid
+import litellm.types.llms.vertex_ai
+
def default_pt(messages):
return " ".join(message["content"] for message in messages)
@@ -111,6 +115,26 @@ def llama_2_chat_pt(messages):
return prompt
+def convert_to_ollama_image(openai_image_url: str):
+ try:
+ if openai_image_url.startswith("http"):
+ openai_image_url = convert_url_to_base64(url=openai_image_url)
+
+ if openai_image_url.startswith("data:image/"):
+ # Extract the base64 image data
+ base64_data = openai_image_url.split("data:image/")[1].split(";base64,")[1]
+ else:
+ base64_data = openai_image_url
+
+ return base64_data
+ except Exception as e:
+ if "Error: Unable to fetch image from URL" in str(e):
+ raise e
+ raise Exception(
+ """Image url not in expected format. Example Expected input - "image_url": "data:image/jpeg;base64,{base64_image}". """
+ )
+
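+# Illustrative behaviour of convert_to_ollama_image above (hypothetical inputs): an http(s) URL
+# is first converted via convert_url_to_base64 and then treated like a data-URI, a
+# "data:image/jpeg;base64,..." URI is stripped down to its raw base64 payload, and a plain
+# base64 string is passed through as-is.
+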
+
def ollama_pt(
model, messages
): # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
@@ -143,8 +167,10 @@ def ollama_pt(
if element["type"] == "text":
prompt += element["text"]
elif element["type"] == "image_url":
- image_url = element["image_url"]["url"]
- images.append(image_url)
+ base64_image = convert_to_ollama_image(
+ element["image_url"]["url"]
+ )
+ images.append(base64_image)
return {"prompt": prompt, "images": images}
else:
prompt = "".join(
@@ -841,6 +867,175 @@ def anthropic_messages_pt_xml(messages: list):
# ------------------------------------------------------------------------------
+def infer_protocol_value(
+ value: Any,
+) -> Literal[
+ "string_value",
+ "number_value",
+ "bool_value",
+ "struct_value",
+ "list_value",
+ "null_value",
+ "unknown",
+]:
+ if value is None:
+ return "null_value"
+    if isinstance(value, bool):
+        # bool must be checked before int/float, since bool is a subclass of int
+        return "bool_value"
+    if isinstance(value, int) or isinstance(value, float):
+        return "number_value"
+    if isinstance(value, str):
+        return "string_value"
+ if isinstance(value, dict):
+ return "struct_value"
+ if isinstance(value, list):
+ return "list_value"
+
+ return "unknown"
+
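+# Illustrative mapping by infer_protocol_value above: 3 and 4.5 -> "number_value",
+# True -> "bool_value", "Boston" -> "string_value", {"unit": "celsius"} -> "struct_value",
+# [1, 2] -> "list_value", None -> "null_value", anything else -> "unknown".
+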
+
+def convert_to_gemini_tool_call_invoke(
+ tool_calls: list,
+) -> List[litellm.types.llms.vertex_ai.PartType]:
+ """
+ OpenAI tool invokes:
+ {
+ "role": "assistant",
+ "content": null,
+ "tool_calls": [
+ {
+ "id": "call_abc123",
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "arguments": "{\n\"location\": \"Boston, MA\"\n}"
+ }
+ }
+ ]
+ },
+ """
+ """
+ Gemini tool call invokes: - https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#submit-api-output
+ content {
+ role: "model"
+ parts [
+ {
+ function_call {
+ name: "get_current_weather"
+ args {
+ fields {
+ key: "unit"
+ value {
+ string_value: "fahrenheit"
+ }
+ }
+ fields {
+ key: "predicted_temperature"
+ value {
+ number_value: 45
+ }
+ }
+ fields {
+ key: "location"
+ value {
+ string_value: "Boston, MA"
+ }
+ }
+ }
+ },
+ {
+ function_call {
+ name: "get_current_weather"
+ args {
+ fields {
+ key: "location"
+ value {
+ string_value: "San Francisco"
+ }
+ }
+ }
+ }
+ }
+ ]
+ }
+ """
+
+ """
+ - json.load the arguments
+ - iterate through arguments -> create a FunctionCallArgs for each field
+ """
+ try:
+ _parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
+ for tool in tool_calls:
+ if "function" in tool:
+ name = tool["function"].get("name", "")
+ arguments = tool["function"].get("arguments", "")
+ arguments_dict = json.loads(arguments)
+ for k, v in arguments_dict.items():
+ inferred_protocol_value = infer_protocol_value(value=v)
+ _field = litellm.types.llms.vertex_ai.Field(
+ key=k, value={inferred_protocol_value: v}
+ )
+ _fields = litellm.types.llms.vertex_ai.FunctionCallArgs(
+ fields=_field
+ )
+ function_call = litellm.types.llms.vertex_ai.FunctionCall(
+ name=name,
+ args=_fields,
+ )
+ _parts_list.append(
+ litellm.types.llms.vertex_ai.PartType(function_call=function_call)
+ )
+ return _parts_list
+ except Exception as e:
+ raise Exception(
+ "Unable to convert openai tool calls={} to gemini tool calls. Received error={}".format(
+ tool_calls, str(e)
+ )
+ )
+
+
+def convert_to_gemini_tool_call_result(
+ message: dict,
+) -> litellm.types.llms.vertex_ai.PartType:
+ """
+ OpenAI message with a tool result looks like:
+ {
+ "tool_call_id": "tool_1",
+ "role": "tool",
+ "name": "get_current_weather",
+ "content": "function result goes here",
+ },
+
+ OpenAI message with a function call result looks like:
+ {
+ "role": "function",
+ "name": "get_current_weather",
+ "content": "function result goes here",
+ }
+ """
+ content = message.get("content", "")
+ name = message.get("name", "")
+
+ # We can't determine from openai message format whether it's a successful or
+ # error call result so default to the successful result template
+ inferred_content_value = infer_protocol_value(value=content)
+
+ _field = litellm.types.llms.vertex_ai.Field(
+ key="content", value={inferred_content_value: content}
+ )
+
+ _function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
+
+ _function_response = litellm.types.llms.vertex_ai.FunctionResponse(
+ name=name, response=_function_call_args
+ )
+
+ _part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
+
+ return _part
+
+
def convert_to_anthropic_tool_result(message: dict) -> dict:
"""
OpenAI message with a tool result looks like:
@@ -1328,6 +1523,7 @@ def _gemini_vision_convert_messages(messages: list):
# Case 1: Image from URL
image = _load_image_from_url(img)
processed_images.append(image)
+
else:
try:
from PIL import Image
@@ -1335,8 +1531,23 @@ def _gemini_vision_convert_messages(messages: list):
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
- # Case 2: Image filepath (e.g. temp.jpeg) given
- image = Image.open(img)
+
+ if "base64" in img:
+ # Case 2: Base64 image data
+ import base64
+ import io
+
+ # Extract the base64 image data
+ base64_data = img.split("base64,")[1]
+
+ # Decode the base64 image data
+ image_data = base64.b64decode(base64_data)
+
+ # Load the image from the decoded data
+ image = Image.open(io.BytesIO(image_data))
+ else:
+ # Case 3: Image filepath (e.g. temp.jpeg) given
+ image = Image.open(img)
processed_images.append(image)
content = [prompt] + processed_images
return content
@@ -1513,7 +1724,7 @@ def prompt_factory(
elif custom_llm_provider == "clarifai":
if "claude" in model:
return anthropic_pt(messages=messages)
-
+
elif custom_llm_provider == "perplexity":
for message in messages:
message.pop("name", None)
diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index c29728134..386d24f59 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -2,11 +2,12 @@ import os, types
import json
import requests # type: ignore
import time
-from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage
-import litellm
+from typing import Callable, Optional, Union, Tuple, Any
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
+import litellm, asyncio
import httpx # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
class ReplicateError(Exception):
@@ -145,6 +146,65 @@ def start_prediction(
)
+async def async_start_prediction(
+ version_id,
+ input_data,
+ api_token,
+ api_base,
+ logging_obj,
+ print_verbose,
+ http_handler: AsyncHTTPHandler,
+) -> str:
+ base_url = api_base
+ if "deployments" in version_id:
+ print_verbose("\nLiteLLM: Request to custom replicate deployment")
+ version_id = version_id.replace("deployments/", "")
+ base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
+ print_verbose(f"Deployment base URL: {base_url}\n")
+ else: # assume it's a model
+ base_url = f"https://api.replicate.com/v1/models/{version_id}"
+ headers = {
+ "Authorization": f"Token {api_token}",
+ "Content-Type": "application/json",
+ }
+
+ initial_prediction_data = {
+ "input": input_data,
+ }
+
+ if ":" in version_id and len(version_id) > 64:
+ model_parts = version_id.split(":")
+ if (
+ len(model_parts) > 1 and len(model_parts[1]) == 64
+        ):  ## checks if model name has a 64-character version hash - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
+ initial_prediction_data["version"] = model_parts[1]
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input_data["prompt"],
+ api_key="",
+ additional_args={
+ "complete_input_dict": initial_prediction_data,
+ "headers": headers,
+ "api_base": base_url,
+ },
+ )
+
+ response = await http_handler.post(
+ url="{}/predictions".format(base_url),
+ data=json.dumps(initial_prediction_data),
+ headers=headers,
+ )
+
+ if response.status_code == 201:
+ response_data = response.json()
+ return response_data.get("urls", {}).get("get")
+ else:
+ raise ReplicateError(
+ response.status_code, f"Failed to start prediction {response.text}"
+ )
+
+
# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):
output_string = ""
@@ -178,6 +238,40 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
return output_string, logs
+async def async_handle_prediction_response(
+ prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
+) -> Tuple[str, Any]:
+ output_string = ""
+ headers = {
+ "Authorization": f"Token {api_token}",
+ "Content-Type": "application/json",
+ }
+
+ status = ""
+ logs = ""
+    while status not in ["succeeded", "failed", "canceled"]:
+ print_verbose(f"replicate: polling endpoint: {prediction_url}")
+ await asyncio.sleep(0.5)
+ response = await http_handler.get(prediction_url, headers=headers)
+ if response.status_code == 200:
+ response_data = response.json()
+ if "output" in response_data:
+ output_string = "".join(response_data["output"])
+ print_verbose(f"Non-streamed output:{output_string}")
+ status = response_data.get("status", None)
+ logs = response_data.get("logs", "")
+ if status == "failed":
+ replicate_error = response_data.get("error", "")
+ raise ReplicateError(
+ status_code=400,
+ message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
+ )
+ else:
+ # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+ print_verbose("Replicate: Failed to fetch prediction status and output.")
+ return output_string, logs
+
+
# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
previous_output = ""
@@ -214,6 +308,45 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
)
+# Function to handle prediction response (streaming)
+async def async_handle_prediction_response_streaming(
+ prediction_url, api_token, print_verbose
+):
+ http_handler = AsyncHTTPHandler(concurrent_limit=1)
+ previous_output = ""
+ output_string = ""
+
+ headers = {
+ "Authorization": f"Token {api_token}",
+ "Content-Type": "application/json",
+ }
+ status = ""
+    while status not in ["succeeded", "failed", "canceled"]:
+ await asyncio.sleep(0.5) # prevent being rate limited by replicate
+ print_verbose(f"replicate: polling endpoint: {prediction_url}")
+ response = await http_handler.get(prediction_url, headers=headers)
+ if response.status_code == 200:
+ response_data = response.json()
+ status = response_data["status"]
+ if "output" in response_data:
+ output_string = "".join(response_data["output"])
+ new_output = output_string[len(previous_output) :]
+ print_verbose(f"New chunk: {new_output}")
+ yield {"output": new_output, "status": status}
+ previous_output = output_string
+ status = response_data["status"]
+ if status == "failed":
+ replicate_error = response_data.get("error", "")
+ raise ReplicateError(
+ status_code=400, message=f"Error: {replicate_error}"
+ )
+ else:
+            # this can fail temporarily; it does not mean the replicate request failed - the request only fails when status == "failed"
+            print_verbose(
+                f"Replicate: Failed to fetch prediction status and output. {response.status_code} {response.text}"
+            )
+
+
# Function to extract version ID from model string
def model_to_version_id(model):
if ":" in model:
@@ -222,6 +355,39 @@ def model_to_version_id(model):
return model
+def process_response(
+ model_response: ModelResponse,
+ result: str,
+ model: str,
+ encoding: Any,
+ prompt: str,
+) -> ModelResponse:
+ if len(result) == 0: # edge case, where result from replicate is empty
+ result = " "
+
+ ## Building RESPONSE OBJECT
+ if len(result) > 1:
+ model_response["choices"][0]["message"]["content"] = result
+
+ # Calculate usage
+ prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
+ completion_tokens = len(
+ encoding.encode(
+ model_response["choices"][0]["message"].get("content", ""),
+ disallowed_special=(),
+ )
+ )
+ model_response["model"] = "replicate/" + model
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ setattr(model_response, "usage", usage)
+
+ return model_response
+
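+# Illustrative note on process_response above (hypothetical values): for result="Paris" and
+# model="meta/llama-2-70b-chat", the returned ModelResponse carries
+# model="replicate/meta/llama-2-70b-chat" and a Usage object whose token counts come from
+# running the provided encoding over the prompt and completion text.
+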
+
# Main function for prediction completion
def completion(
model: str,
@@ -229,14 +395,15 @@ def completion(
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
+ optional_params: dict,
logging_obj,
api_key,
encoding,
custom_prompt_dict={},
- optional_params=None,
litellm_params=None,
logger_fn=None,
-):
+ acompletion=None,
+) -> Union[ModelResponse, CustomStreamWrapper]:
# Start a prediction and get the prediction URL
version_id = model_to_version_id(model)
## Load Config
@@ -274,6 +441,12 @@ def completion(
else:
prompt = prompt_factory(model=model, messages=messages)
+ if prompt is None or not isinstance(prompt, str):
+ raise ReplicateError(
+ status_code=400,
+ message="LiteLLM Error - prompt is not a string - {}".format(prompt),
+ )
+
# If system prompt is supported, and a system prompt is provided, use it
if system_prompt is not None:
input_data = {
@@ -285,6 +458,20 @@ def completion(
else:
input_data = {"prompt": prompt, **optional_params}
+ if acompletion is not None and acompletion == True:
+ return async_completion(
+ model_response=model_response,
+ model=model,
+ prompt=prompt,
+ encoding=encoding,
+ optional_params=optional_params,
+ version_id=version_id,
+ input_data=input_data,
+ api_key=api_key,
+ api_base=api_base,
+ logging_obj=logging_obj,
+ print_verbose=print_verbose,
+ ) # type: ignore
## COMPLETION CALL
## Replicate Compeltion calls have 2 steps
## Step1: Start Prediction: gets a prediction url
@@ -293,6 +480,7 @@ def completion(
model_response["created"] = int(
time.time()
) # for pricing this must remain right before calling api
+
prediction_url = start_prediction(
version_id,
input_data,
@@ -306,9 +494,10 @@ def completion(
# Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] == True:
print_verbose("streaming request")
- return handle_prediction_response_streaming(
+ _response = handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
)
+ return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
else:
result, logs = handle_prediction_response(
prediction_url, api_key, print_verbose
@@ -328,29 +517,56 @@ def completion(
print_verbose(f"raw model_response: {result}")
- if len(result) == 0: # edge case, where result from replicate is empty
- result = " "
-
- ## Building RESPONSE OBJECT
- if len(result) > 1:
- model_response["choices"][0]["message"]["content"] = result
-
- # Calculate usage
- prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
- completion_tokens = len(
- encoding.encode(
- model_response["choices"][0]["message"].get("content", ""),
- disallowed_special=(),
- )
+ return process_response(
+ model_response=model_response,
+ result=result,
+ model=model,
+ encoding=encoding,
+ prompt=prompt,
)
- model_response["model"] = "replicate/" + model
- usage = Usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
+
+
+async def async_completion(
+ model_response: ModelResponse,
+ model: str,
+ prompt: str,
+ encoding,
+ optional_params: dict,
+ version_id,
+ input_data,
+ api_key,
+ api_base,
+ logging_obj,
+ print_verbose,
+) -> Union[ModelResponse, CustomStreamWrapper]:
+ http_handler = AsyncHTTPHandler(concurrent_limit=1)
+ prediction_url = await async_start_prediction(
+ version_id,
+ input_data,
+ api_key,
+ api_base,
+ logging_obj=logging_obj,
+ print_verbose=print_verbose,
+ http_handler=http_handler,
+ )
+
+ if "stream" in optional_params and optional_params["stream"] == True:
+ _response = async_handle_prediction_response_streaming(
+ prediction_url, api_key, print_verbose
)
- setattr(model_response, "usage", usage)
- return model_response
+ return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
+
+ result, logs = await async_handle_prediction_response(
+ prediction_url, api_key, print_verbose, http_handler=http_handler
+ )
+
+ return process_response(
+ model_response=model_response,
+ result=result,
+ model=model,
+ encoding=encoding,
+ prompt=prompt,
+ )
# # Example usage:
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 84fec734f..dc185aef9 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -3,10 +3,15 @@ import json
from enum import Enum
import requests # type: ignore
import time
-from typing import Callable, Optional, Union, List
+from typing import Callable, Optional, Union, List, Literal
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
+from litellm.types.llms.vertex_ai import *
+from litellm.llms.prompt_templates.factory import (
+ convert_to_gemini_tool_call_result,
+ convert_to_gemini_tool_call_invoke,
+)
class VertexAIError(Exception):
@@ -283,6 +288,139 @@ def _load_image_from_url(image_url: str):
return Image.from_bytes(data=image_bytes)
+def _convert_gemini_role(role: str) -> Literal["user", "model"]:
+ if role == "user":
+ return "user"
+ else:
+ return "model"
+
+
+def _process_gemini_image(image_url: str) -> PartType:
+ try:
+ if "gs://" in image_url:
+ # Case 1: Images with Cloud Storage URIs
+ # The supported MIME types for images include image/png and image/jpeg.
+ part_mime = "image/png" if "png" in image_url else "image/jpeg"
+ _file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
+ return PartType(file_data=_file_data)
+ elif "https:/" in image_url:
+ # Case 2: Images with direct links
+ image = _load_image_from_url(image_url)
+ _blob = BlobType(data=image.data, mime_type=image._mime_type)
+ return PartType(inline_data=_blob)
+ elif ".mp4" in image_url and "gs://" in image_url:
+ # Case 3: Videos with Cloud Storage URIs
+ part_mime = "video/mp4"
+ _file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
+ return PartType(file_data=_file_data)
+ elif "base64" in image_url:
+ # Case 4: Images with base64 encoding
+ import base64, re
+
+ # base 64 is passed as data:image/jpeg;base64,
+ image_metadata, img_without_base_64 = image_url.split(",")
+
+ # read mime_type from img_without_base_64=data:image/jpeg;base64
+ # Extract MIME type using regular expression
+ mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
+
+ if mime_type_match:
+ mime_type = mime_type_match.group(1)
+ else:
+ mime_type = "image/jpeg"
+ decoded_img = base64.b64decode(img_without_base_64)
+ _blob = BlobType(data=decoded_img, mime_type=mime_type)
+ return PartType(inline_data=_blob)
+ raise Exception("Invalid image received - {}".format(image_url))
+ except Exception as e:
+ raise e
+
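+# Illustrative routing by _process_gemini_image above (hypothetical URLs):
+#   "gs://bucket/clip.mp4"          -> file_data part with mime_type="video/mp4"
+#   "gs://bucket/cat.png"           -> file_data part with mime_type="image/png"
+#   "https://example.com/cat.png"   -> inline_data blob loaded via _load_image_from_url
+#   "data:image/jpeg;base64,..."    -> inline_data blob with mime_type="image/jpeg"
+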
+
+def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
+ """
+ Converts given messages from OpenAI format to Gemini format
+
+ - Parts must be iterable
+ - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
+ - Please ensure that function response turn comes immediately after a function call turn
+ """
+ user_message_types = {"user", "system"}
+ contents: List[ContentType] = []
+
+ msg_i = 0
+ while msg_i < len(messages):
+ user_content: List[PartType] = []
+ init_msg_i = msg_i
+ ## MERGE CONSECUTIVE USER CONTENT ##
+ while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
+ if isinstance(messages[msg_i]["content"], list):
+ _parts: List[PartType] = []
+ for element in messages[msg_i]["content"]:
+ if isinstance(element, dict):
+ if element["type"] == "text":
+ _part = PartType(text=element["text"])
+ _parts.append(_part)
+ elif element["type"] == "image_url":
+ image_url = element["image_url"]["url"]
+ _part = _process_gemini_image(image_url=image_url)
+ _parts.append(_part) # type: ignore
+ user_content.extend(_parts)
+ else:
+ _part = PartType(text=messages[msg_i]["content"])
+ user_content.append(_part)
+
+ msg_i += 1
+
+ if user_content:
+ contents.append(ContentType(role="user", parts=user_content))
+ assistant_content = []
+ ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+ while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+ if isinstance(messages[msg_i]["content"], list):
+ _parts = []
+ for element in messages[msg_i]["content"]:
+ if isinstance(element, dict):
+ if element["type"] == "text":
+ _part = PartType(text=element["text"])
+ _parts.append(_part)
+ elif element["type"] == "image_url":
+ image_url = element["image_url"]["url"]
+ _part = _process_gemini_image(image_url=image_url)
+ _parts.append(_part) # type: ignore
+ assistant_content.extend(_parts)
+ elif messages[msg_i].get(
+ "tool_calls", []
+ ): # support assistant tool invoke conversion
+ assistant_content.extend(
+ convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
+ )
+ else:
+ assistant_text = (
+ messages[msg_i].get("content") or ""
+ ) # either string or none
+ if assistant_text:
+ assistant_content.append(PartType(text=assistant_text))
+
+ msg_i += 1
+
+ if assistant_content:
+ contents.append(ContentType(role="model", parts=assistant_content))
+
+ ## APPEND TOOL CALL MESSAGES ##
+ if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+ _part = convert_to_gemini_tool_call_result(messages[msg_i])
+ contents.append(ContentType(parts=[_part])) # type: ignore
+ msg_i += 1
+ if msg_i == init_msg_i: # prevent infinite loops
+ raise Exception(
+ "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+ messages[msg_i]
+ )
+ )
+
+ return contents
+
+
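A rough usage sketch of the converter above, with an illustrative OpenAI-style message list (not taken from the source):

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is in this image?"},
        {"role": "assistant", "content": "It looks like a cat."},
        {"role": "user", "content": "Are you sure?"},
    ]
    contents = _gemini_convert_messages_with_history(messages=messages)
    # the system prompt and the first user turn are merged into a single 'user' content,
    # so the resulting roles alternate: user, model, user
    print(contents)
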
def _gemini_vision_convert_messages(messages: list):
"""
Converts given messages for GPT-4 Vision to Gemini format.
@@ -396,10 +534,10 @@ def completion(
print_verbose: Callable,
encoding,
logging_obj,
+ optional_params: dict,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
- optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
@@ -556,6 +694,7 @@ def completion(
"model_response": model_response,
"encoding": encoding,
"messages": messages,
+ "request_str": request_str,
"print_verbose": print_verbose,
"client_options": client_options,
"instances": instances,
@@ -574,11 +713,9 @@ def completion(
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
- prompt, images = _gemini_vision_convert_messages(messages=messages)
- content = [prompt] + images
+ content = _gemini_convert_messages_with_history(messages=messages)
stream = optional_params.pop("stream", False)
if stream == True:
-
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(
input=prompt,
@@ -589,7 +726,7 @@ def completion(
},
)
- model_response = llm_model.generate_content(
+ _model_response = llm_model.generate_content(
contents=content,
generation_config=optional_params,
safety_settings=safety_settings,
@@ -597,7 +734,7 @@ def completion(
tools=tools,
)
- return model_response
+ return _model_response
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
@@ -850,12 +987,12 @@ async def async_completion(
mode: str,
prompt: str,
model: str,
+ messages: list,
model_response: ModelResponse,
- logging_obj=None,
- request_str=None,
+ request_str: str,
+ print_verbose: Callable,
+ logging_obj,
encoding=None,
- messages=None,
- print_verbose=None,
client_options=None,
instances=None,
vertex_project=None,
@@ -875,8 +1012,7 @@ async def async_completion(
tools = optional_params.pop("tools", None)
stream = optional_params.pop("stream", False)
- prompt, images = _gemini_vision_convert_messages(messages=messages)
- content = [prompt] + images
+ content = _gemini_convert_messages_with_history(messages=messages)
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
@@ -1076,11 +1212,11 @@ async def async_streaming(
prompt: str,
model: str,
model_response: ModelResponse,
- logging_obj=None,
- request_str=None,
+ messages: list,
+ print_verbose: Callable,
+ logging_obj,
+ request_str: str,
encoding=None,
- messages=None,
- print_verbose=None,
client_options=None,
instances=None,
vertex_project=None,
@@ -1097,8 +1233,8 @@ async def async_streaming(
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
- prompt, images = _gemini_vision_convert_messages(messages=messages)
- content = [prompt] + images
+ content = _gemini_convert_messages_with_history(messages=messages)
+
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(
input=prompt,
diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py
index 3bdcf4fd6..065294280 100644
--- a/litellm/llms/vertex_ai_anthropic.py
+++ b/litellm/llms/vertex_ai_anthropic.py
@@ -35,7 +35,7 @@ class VertexAIError(Exception):
class VertexAIAnthropicConfig:
"""
- Reference: https://docs.anthropic.com/claude/reference/messages_post
+ Reference:https://docs.anthropic.com/claude/reference/messages_post
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
new file mode 100644
index 000000000..b8c698c90
--- /dev/null
+++ b/litellm/llms/vertex_httpx.py
@@ -0,0 +1,224 @@
+import os, types
+import json
+from enum import Enum
+import requests # type: ignore
+import time
+from typing import Callable, Optional, Union, List, Any, Tuple
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
+import litellm, uuid
+import httpx, inspect # type: ignore
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from .base import BaseLLM
+
+
+class VertexAIError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(
+ method="POST", url=" https://cloud.google.com/vertex-ai/"
+ )
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
+
+
+class VertexLLM(BaseLLM):
+ def __init__(self) -> None:
+ super().__init__()
+ self.access_token: Optional[str] = None
+ self.refresh_token: Optional[str] = None
+ self._credentials: Optional[Any] = None
+ self.project_id: Optional[str] = None
+ self.async_handler: Optional[AsyncHTTPHandler] = None
+
+ def load_auth(self) -> Tuple[Any, str]:
+ from google.auth.transport.requests import Request # type: ignore[import-untyped]
+ from google.auth.credentials import Credentials # type: ignore[import-untyped]
+ import google.auth as google_auth
+
+ credentials, project_id = google_auth.default(
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+
+ credentials.refresh(Request())
+
+ if not project_id:
+ raise ValueError("Could not resolve project_id")
+
+ if not isinstance(project_id, str):
+ raise TypeError(
+ f"Expected project_id to be a str but got {type(project_id)}"
+ )
+
+ return credentials, project_id
+
+ def refresh_auth(self, credentials: Any) -> None:
+ from google.auth.transport.requests import Request # type: ignore[import-untyped]
+
+ credentials.refresh(Request())
+
+ def _prepare_request(self, request: httpx.Request) -> None:
+ access_token = self._ensure_access_token()
+
+ if request.headers.get("Authorization"):
+ # already authenticated, nothing for us to do
+ return
+
+ request.headers["Authorization"] = f"Bearer {access_token}"
+
+ def _ensure_access_token(self) -> str:
+ if self.access_token is not None:
+ return self.access_token
+
+ if not self._credentials:
+ self._credentials, project_id = self.load_auth()
+ if not self.project_id:
+ self.project_id = project_id
+ else:
+ self.refresh_auth(self._credentials)
+
+ if not self._credentials.token:
+ raise RuntimeError("Could not resolve API token from the environment")
+
+ assert isinstance(self._credentials.token, str)
+ return self._credentials.token
+
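A minimal sketch of the auth flow implemented above, assuming Application Default Credentials are available in the environment:

    vertex_llm = VertexLLM()
    token = vertex_llm._ensure_access_token()  # first call: google.auth.default() + refresh, caches credentials and project_id
    token = vertex_llm._ensure_access_token()  # later calls: refresh the cached credentials and return their token
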
+ def image_generation(
+ self,
+ prompt: str,
+ vertex_project: str,
+ vertex_location: str,
+ model: Optional[
+ str
+ ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
+ client: Optional[AsyncHTTPHandler] = None,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[int] = None,
+ logging_obj=None,
+ model_response=None,
+ aimg_generation=False,
+ ):
+ if aimg_generation == True:
+ response = self.aimage_generation(
+ prompt=prompt,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ model=model,
+ client=client,
+ optional_params=optional_params,
+ timeout=timeout,
+ logging_obj=logging_obj,
+ model_response=model_response,
+ )
+ return response
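+ # NOTE: only the async path is implemented here; a synchronous call currently returns None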
+
+ async def aimage_generation(
+ self,
+ prompt: str,
+ vertex_project: str,
+ vertex_location: str,
+ model_response: litellm.ImageResponse,
+ model: Optional[
+ str
+ ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
+ client: Optional[AsyncHTTPHandler] = None,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[int] = None,
+ logging_obj=None,
+ ):
+ response = None
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
+ else:
+ self.async_handler = client # type: ignore
+
+ # make POST request to
+ # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+ """
+ Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+ curl -X POST \
+ -H "Authorization: Bearer $(gcloud auth print-access-token)" \
+ -H "Content-Type: application/json; charset=utf-8" \
+ -d {
+ "instances": [
+ {
+ "prompt": "a cat"
+ }
+ ],
+ "parameters": {
+ "sampleCount": 1
+ }
+ } \
+ "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
+ """
+ auth_header = self._ensure_access_token()
+ optional_params = optional_params or {
+ "sampleCount": 1
+ } # default optional params
+
+ request_data = {
+ "instances": [{"prompt": prompt}],
+ "parameters": optional_params,
+ }
+
+ request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ response = await self.async_handler.post(
+ url=url,
+ headers={
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {auth_header}",
+ },
+ data=json.dumps(request_data),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+ """
+ Vertex AI Image generation response example:
+ {
+ "predictions": [
+ {
+ "bytesBase64Encoded": "BASE64_IMG_BYTES",
+ "mimeType": "image/png"
+ },
+ {
+ "mimeType": "image/png",
+ "bytesBase64Encoded": "BASE64_IMG_BYTES"
+ }
+ ]
+ }
+ """
+
+ _json_response = response.json()
+ _predictions = _json_response["predictions"]
+
+ _response_data: List[litellm.ImageObject] = []
+ for _prediction in _predictions:
+ _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
+ image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
+ _response_data.append(image_object)
+
+ model_response.data = _response_data
+
+ return model_response
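With the new VertexLLM client wired into litellm.main (see further below), Vertex image generation can be exercised roughly as follows; the GCP project and location are illustrative placeholders, and the model name matches the pricing entry added later in this diff:

    import litellm

    response = litellm.image_generation(
        model="vertex_ai/imagegeneration@006",
        prompt="a watercolor painting of a lighthouse",
        vertex_project="my-gcp-project",   # hypothetical project id
        vertex_location="us-central1",
    )
    print(response.data[0].b64_json[:32])  # predictions are returned base64-encoded
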
diff --git a/litellm/main.py b/litellm/main.py
index 3429cab4d..525a39d68 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -14,7 +14,6 @@ from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
-
import litellm
from ._logging import verbose_logger
from litellm import ( # type: ignore
@@ -73,12 +72,14 @@ from .llms import (
)
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.azure import AzureChatCompletion
+from .llms.databricks import DatabricksChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
+from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
@@ -90,6 +91,7 @@ import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict, Union, Mapping
from .caching import enable_cache, disable_cache, update_cache
+from .types.llms.openai import HttpxBinaryResponseContent
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
@@ -110,6 +112,7 @@ from litellm.utils import (
####### ENVIRONMENT VARIABLES ###################
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
+databricks_chat_completions = DatabricksChatCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
@@ -118,6 +121,7 @@ huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
+vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################
@@ -290,6 +294,7 @@ async def acompletion(
"api_version": api_version,
"api_key": api_key,
"model_list": model_list,
+ "extra_headers": extra_headers,
"acompletion": True, # assuming this is a required parameter
}
if custom_llm_provider is None:
@@ -320,12 +325,14 @@ async def acompletion(
or custom_llm_provider == "huggingface"
or custom_llm_provider == "ollama"
or custom_llm_provider == "ollama_chat"
+ or custom_llm_provider == "replicate"
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
or custom_llm_provider == "predibase"
- or (custom_llm_provider == "bedrock" and "cohere" in model)
+ or custom_llm_provider == "bedrock"
+ or custom_llm_provider == "databricks"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@@ -367,6 +374,8 @@ async def acompletion(
async def _async_streaming(response, model, custom_llm_provider, args):
try:
print_verbose(f"received response in _async_streaming: {response}")
+ if asyncio.iscoroutine(response):
+ response = await response
async for line in response:
print_verbose(f"line in async streaming: {line}")
yield line
@@ -412,6 +421,8 @@ def mock_completion(
api_key="mock-key",
)
if isinstance(mock_response, Exception):
+ if isinstance(mock_response, openai.APIError):
+ raise mock_response
raise litellm.APIError(
status_code=500, # type: ignore
message=str(mock_response),
@@ -455,7 +466,9 @@ def mock_completion(
return model_response
- except:
+ except Exception as e:
+ if isinstance(e, openai.APIError):
+ raise e
traceback.print_exc()
raise Exception("Mock completion response failed")
@@ -481,7 +494,7 @@ def completion(
response_format: Optional[dict] = None,
seed: Optional[int] = None,
tools: Optional[List] = None,
- tool_choice: Optional[str] = None,
+ tool_choice: Optional[Union[str, dict]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
deployment_id=None,
@@ -552,7 +565,7 @@ def completion(
model_info = kwargs.get("model_info", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
fallbacks = kwargs.get("fallbacks", None)
- headers = kwargs.get("headers", None)
+ headers = kwargs.get("headers", None) or extra_headers
num_retries = kwargs.get("num_retries", None) ## deprecated
max_retries = kwargs.get("max_retries", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
@@ -667,6 +680,7 @@ def completion(
"region_name",
"allowed_model_region",
"model_config",
+ "fastest_response",
]
default_params = openai_params + litellm_params
@@ -674,20 +688,6 @@ def completion(
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
- ### TIMEOUT LOGIC ###
- timeout = timeout or kwargs.get("request_timeout", 600) or 600
- # set timeout for 10 minutes by default
-
- if (
- timeout is not None
- and isinstance(timeout, httpx.Timeout)
- and supports_httpx_timeout(custom_llm_provider) == False
- ):
- read_timeout = timeout.read or 600
- timeout = read_timeout # default 10 min timeout
- elif timeout is not None and not isinstance(timeout, httpx.Timeout):
- timeout = float(timeout) # type: ignore
-
try:
if base_url is not None:
api_base = base_url
@@ -727,6 +727,16 @@ def completion(
"aws_region_name", None
) # support region-based pricing for bedrock
+ ### TIMEOUT LOGIC ###
+ timeout = timeout or kwargs.get("request_timeout", 600) or 600
+ # set timeout for 10 minutes by default
+ if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout(
+ custom_llm_provider
+ ):
+ timeout = timeout.read or 600 # default 10 min timeout
+ elif not isinstance(timeout, httpx.Timeout):
+ timeout = float(timeout) # type: ignore
+
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None:
litellm.register_model(
@@ -860,6 +870,7 @@ def completion(
user=user,
optional_params=optional_params,
litellm_params=litellm_params,
+ custom_llm_provider=custom_llm_provider,
)
if mock_response:
return mock_completion(
@@ -1192,7 +1203,7 @@ def completion(
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
- model_response = replicate.completion(
+ model_response = replicate.completion( # type: ignore
model=model,
messages=messages,
api_base=api_base,
@@ -1205,12 +1216,10 @@ def completion(
api_key=replicate_key,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
+ acompletion=acompletion,
)
- if "stream" in optional_params and optional_params["stream"] == True:
- # don't try to access stream object,
- model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
- if optional_params.get("stream", False) or acompletion == True:
+ if optional_params.get("stream", False) == True:
## LOGGING
logging.post_call(
input=messages,
@@ -1616,6 +1625,61 @@ def completion(
)
return response
response = model_response
+ elif custom_llm_provider == "databricks":
+ api_base = (
+ api_base # for databricks we check in get_llm_provider and pass in the api base from there
+ or litellm.api_base
+ or os.getenv("DATABRICKS_API_BASE")
+ )
+
+ # set API KEY
+ api_key = (
+ api_key
+ or litellm.api_key # for databricks we check in get_llm_provider and pass in the api key from there
+ or litellm.databricks_key
+ or get_secret("DATABRICKS_API_KEY")
+ )
+
+ headers = headers or litellm.headers
+
+ ## COMPLETION CALL
+ try:
+ response = databricks_chat_completions.completion(
+ model=model,
+ messages=messages,
+ headers=headers,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ api_key=api_key,
+ api_base=api_base,
+ acompletion=acompletion,
+ logging_obj=logging,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ timeout=timeout, # type: ignore
+ custom_prompt_dict=custom_prompt_dict,
+ client=client, # pass AsyncOpenAI, OpenAI client
+ encoding=encoding,
+ )
+ except Exception as e:
+ ## LOGGING - log the original exception returned
+ logging.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=str(e),
+ additional_args={"headers": headers},
+ )
+ raise e
+
+ if optional_params.get("stream", False):
+ ## LOGGING
+ logging.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=response,
+ additional_args={"headers": headers},
+ )
elif custom_llm_provider == "openrouter":
api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1"
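A hedged usage sketch for the new Databricks branch added above; the workspace token and serving-endpoint URL are placeholders:

    import os
    import litellm

    os.environ["DATABRICKS_API_KEY"] = "dapi-..."  # placeholder token
    os.environ["DATABRICKS_API_BASE"] = "https://my-workspace.cloud.databricks.com/serving-endpoints"  # placeholder URL

    response = litellm.completion(
        model="databricks/databricks-dbrx-instruct",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
    )
    print(response.choices[0].message.content)
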
@@ -1984,23 +2048,9 @@ def completion(
# boto3 reads keys from .env
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
- if "cohere" in model:
- response = bedrock_chat_completion.completion(
- model=model,
- messages=messages,
- custom_prompt_dict=litellm.custom_prompt_dict,
- model_response=model_response,
- print_verbose=print_verbose,
- optional_params=optional_params,
- litellm_params=litellm_params,
- logger_fn=logger_fn,
- encoding=encoding,
- logging_obj=logging,
- extra_headers=extra_headers,
- timeout=timeout,
- acompletion=acompletion,
- )
- else:
+ if (
+ "aws_bedrock_client" in optional_params
+ ): # use old bedrock flow for aws_bedrock_client users.
response = bedrock.completion(
model=model,
messages=messages,
@@ -2036,7 +2086,23 @@ def completion(
custom_llm_provider="bedrock",
logging_obj=logging,
)
-
+ else:
+ response = bedrock_chat_completion.completion(
+ model=model,
+ messages=messages,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ encoding=encoding,
+ logging_obj=logging,
+ extra_headers=extra_headers,
+ timeout=timeout,
+ acompletion=acompletion,
+ client=client,
+ )
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
@@ -2477,6 +2543,7 @@ def batch_completion(
list: A list of completion results.
"""
args = locals()
+
batch_messages = messages
completions = []
model = model
@@ -2530,7 +2597,15 @@ def batch_completion(
completions.append(future)
# Retrieve the results from the futures
- results = [future.result() for future in completions]
+ # results = [future.result() for future in completions]
+ # return exceptions if any
+ results = []
+ for future in completions:
+ try:
+ results.append(future.result())
+ except Exception as exc:
+ results.append(exc)
+
return results
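Because batch_completion now returns exceptions in place of failed results instead of raising on the first failure, a caller can be sketched as checking each entry (model and prompts are illustrative):

    results = litellm.batch_completion(
        model="gpt-3.5-turbo",
        messages=[
            [{"role": "user", "content": "a prompt that succeeds"}],
            [{"role": "user", "content": "a prompt that errors"}],
        ],
    )
    for result in results:
        if isinstance(result, Exception):
            print(f"request failed: {result}")
        else:
            print(result.choices[0].message.content)
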
@@ -2669,7 +2744,7 @@ def batch_completion_models_all_responses(*args, **kwargs):
### EMBEDDING ENDPOINTS ####################
@client
-async def aembedding(*args, **kwargs):
+async def aembedding(*args, **kwargs) -> EmbeddingResponse:
"""
Asynchronously calls the `embedding` function with the given arguments and keyword arguments.
@@ -2714,12 +2789,13 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai"
+ or custom_llm_provider == "databricks"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
- if isinstance(init_response, dict) or isinstance(
- init_response, ModelResponse
- ): ## CACHING SCENARIO
+ if isinstance(init_response, dict):
+ response = EmbeddingResponse(**init_response)
+ elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
@@ -2759,7 +2835,7 @@ def embedding(
litellm_logging_obj=None,
logger_fn=None,
**kwargs,
-):
+) -> EmbeddingResponse:
"""
Embedding function that calls an API to generate embeddings for the given input.
@@ -2907,7 +2983,7 @@ def embedding(
)
try:
response = None
- logging = litellm_logging_obj
+ logging: Logging = litellm_logging_obj # type: ignore
logging.update_environment_variables(
model=model,
user=user,
@@ -2997,6 +3073,32 @@ def embedding(
client=client,
aembedding=aembedding,
)
+ elif custom_llm_provider == "databricks":
+ api_base = (
+ api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")
+ ) # type: ignore
+
+ # set API KEY
+ api_key = (
+ api_key
+ or litellm.api_key
+ or litellm.databricks_key
+ or get_secret("DATABRICKS_API_KEY")
+ ) # type: ignore
+
+ ## EMBEDDING CALL
+ response = databricks_chat_completions.embedding(
+ model=model,
+ input=input,
+ api_base=api_base,
+ api_key=api_key,
+ logging_obj=logging,
+ timeout=timeout,
+ model_response=EmbeddingResponse(),
+ optional_params=optional_params,
+ client=client,
+ aembedding=aembedding,
+ )
elif custom_llm_provider == "cohere":
cohere_key = (
api_key
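Analogously, a brief sketch for the Databricks embedding path added above, reusing the DATABRICKS_* environment variables from the earlier sketch; the model name comes from the pricing table later in this diff:

    embedding_response = litellm.embedding(
        model="databricks/databricks-bge-large-en",
        input=["hello from litellm"],
    )
    print(len(embedding_response.data[0]["embedding"]))  # 1024-dimensional per the model entry below
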
@@ -3856,6 +3958,36 @@ def image_generation(
model_response=model_response,
aimg_generation=aimg_generation,
)
+ elif custom_llm_provider == "vertex_ai":
+ vertex_ai_project = (
+ optional_params.pop("vertex_project", None)
+ or optional_params.pop("vertex_ai_project", None)
+ or litellm.vertex_project
+ or get_secret("VERTEXAI_PROJECT")
+ )
+ vertex_ai_location = (
+ optional_params.pop("vertex_location", None)
+ or optional_params.pop("vertex_ai_location", None)
+ or litellm.vertex_location
+ or get_secret("VERTEXAI_LOCATION")
+ )
+ vertex_credentials = (
+ optional_params.pop("vertex_credentials", None)
+ or optional_params.pop("vertex_ai_credentials", None)
+ or get_secret("VERTEXAI_CREDENTIALS")
+ )
+ model_response = vertex_chat_completion.image_generation(
+ model=model,
+ prompt=prompt,
+ timeout=timeout,
+ logging_obj=litellm_logging_obj,
+ optional_params=optional_params,
+ model_response=model_response,
+ vertex_project=vertex_ai_project,
+ vertex_location=vertex_ai_location,
+ aimg_generation=aimg_generation,
+ )
+
return model_response
except Exception as e:
## Map to OpenAI Exception
@@ -3999,6 +4131,24 @@ def transcription(
max_retries=max_retries,
)
elif custom_llm_provider == "openai":
+ api_base = (
+ api_base
+ or litellm.api_base
+ or get_secret("OPENAI_API_BASE")
+ or "https://api.openai.com/v1"
+ ) # type: ignore
+ openai.organization = (
+ litellm.organization
+ or get_secret("OPENAI_ORGANIZATION")
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+ )
+ # set API KEY
+ api_key = (
+ api_key
+ or litellm.api_key
+ or litellm.openai_key
+ or get_secret("OPENAI_API_KEY")
+ ) # type: ignore
response = openai_chat_completions.audio_transcriptions(
model=model,
audio_file=file,
@@ -4008,6 +4158,139 @@ def transcription(
timeout=timeout,
logging_obj=litellm_logging_obj,
max_retries=max_retries,
+ api_base=api_base,
+ api_key=api_key,
+ )
+ return response
+
+
+@client
+async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:
+ """
+ Asynchronously calls OpenAI-compatible text-to-speech endpoints.
+ """
+ loop = asyncio.get_event_loop()
+ model = args[0] if len(args) > 0 else kwargs["model"]
+ ### PASS ARGS TO SPEECH ###
+ kwargs["aspeech"] = True
+ custom_llm_provider = kwargs.get("custom_llm_provider", None)
+ try:
+ # Use a partial function to pass your keyword arguments
+ func = partial(speech, *args, **kwargs)
+
+ # Add the context to the function
+ ctx = contextvars.copy_context()
+ func_with_context = partial(ctx.run, func)
+
+ _, custom_llm_provider, _, _ = get_llm_provider(
+ model=model, api_base=kwargs.get("api_base", None)
+ )
+
+ # Await normally
+ init_response = await loop.run_in_executor(None, func_with_context)
+ if asyncio.iscoroutine(init_response):
+ response = await init_response
+ else:
+ # the synchronous call already produced the final response
+ response = init_response
+ return response # type: ignore
+ except Exception as e:
+ custom_llm_provider = custom_llm_provider or "openai"
+ raise exception_type(
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ original_exception=e,
+ completion_kwargs=args,
+ extra_kwargs=kwargs,
+ )
+
+
+@client
+def speech(
+ model: str,
+ input: str,
+ voice: str,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ organization: Optional[str] = None,
+ project: Optional[str] = None,
+ max_retries: Optional[int] = None,
+ metadata: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ response_format: Optional[str] = None,
+ speed: Optional[int] = None,
+ client=None,
+ headers: Optional[dict] = None,
+ custom_llm_provider: Optional[str] = None,
+ aspeech: Optional[bool] = None,
+ **kwargs,
+) -> HttpxBinaryResponseContent:
+
+ model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
+
+ optional_params = {}
+ if response_format is not None:
+ optional_params["response_format"] = response_format
+ if speed is not None:
+ optional_params["speed"] = speed # type: ignore
+
+ if timeout is None:
+ timeout = litellm.request_timeout
+
+ if max_retries is None:
+ max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
+ response: Optional[HttpxBinaryResponseContent] = None
+ if custom_llm_provider == "openai":
+ api_base = (
+ api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+ or litellm.api_base
+ or get_secret("OPENAI_API_BASE")
+ or "https://api.openai.com/v1"
+ ) # type: ignore
+ # set API KEY
+ api_key = (
+ api_key
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+ or litellm.openai_key
+ or get_secret("OPENAI_API_KEY")
+ ) # type: ignore
+
+ organization = (
+ organization
+ or litellm.organization
+ or get_secret("OPENAI_ORGANIZATION")
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+ ) # type: ignore
+
+ project = (
+ project
+ or litellm.project
+ or get_secret("OPENAI_PROJECT")
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+ ) # type: ignore
+
+ headers = headers or litellm.headers
+
+ response = openai_chat_completions.audio_speech(
+ model=model,
+ input=input,
+ voice=voice,
+ optional_params=optional_params,
+ api_key=api_key,
+ api_base=api_base,
+ organization=organization,
+ project=project,
+ max_retries=max_retries,
+ timeout=timeout,
+ client=client, # pass AsyncOpenAI, OpenAI client
+ aspeech=aspeech,
+ )
+
+ if response is None:
+ raise Exception(
+ "Unable to map the custom llm provider={} to a known provider={}.".format(
+ custom_llm_provider, litellm.provider_list
+ )
)
return response
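A minimal sketch of the new speech() entry point; the voice, model name, and output path are illustrative, and stream_to_file is assumed to be available on the returned HttpxBinaryResponseContent from the OpenAI SDK:

    import litellm

    speech_response = litellm.speech(
        model="openai/tts-1",   # routed to the OpenAI branch above
        voice="alloy",
        input="Hello from the LiteLLM speech endpoint.",
    )
    speech_response.stream_to_file("speech.mp3")  # write the binary audio to disk
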
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index ff9194578..f090c5a3f 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -234,6 +234,24 @@
"litellm_provider": "openai",
"mode": "chat"
},
+ "ft:davinci-002": {
+ "max_tokens": 16384,
+ "max_input_tokens": 16384,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000002,
+ "output_cost_per_token": 0.000002,
+ "litellm_provider": "text-completion-openai",
+ "mode": "completion"
+ },
+ "ft:babbage-002": {
+ "max_tokens": 16384,
+ "max_input_tokens": 16384,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000004,
+ "output_cost_per_token": 0.0000004,
+ "litellm_provider": "text-completion-openai",
+ "mode": "completion"
+ },
"text-embedding-3-large": {
"max_tokens": 8191,
"max_input_tokens": 8191,
@@ -500,8 +518,8 @@
"max_tokens": 4096,
"max_input_tokens": 4097,
"max_output_tokens": 4096,
- "input_cost_per_token": 0.0000015,
- "output_cost_per_token": 0.000002,
+ "input_cost_per_token": 0.0000005,
+ "output_cost_per_token": 0.0000015,
"litellm_provider": "azure",
"mode": "chat",
"supports_function_calling": true
@@ -1247,13 +1265,19 @@
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
- "input_cost_per_token": 0.0000015,
- "output_cost_per_token": 0.0000075,
+ "input_cost_per_token": 0.000015,
+ "output_cost_per_token": 0.000075,
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true
},
+ "vertex_ai/imagegeneration@006": {
+ "cost_per_image": 0.020,
+ "litellm_provider": "vertex_ai-image-models",
+ "mode": "image_generation",
+ "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
+ },
"textembedding-gecko": {
"max_tokens": 3072,
"max_input_tokens": 3072,
@@ -1385,6 +1409,24 @@
"mode": "completion",
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
+ "gemini/gemini-1.5-flash-latest": {
+ "max_tokens": 8192,
+ "max_input_tokens": 1000000,
+ "max_output_tokens": 8192,
+ "max_images_per_prompt": 3000,
+ "max_videos_per_prompt": 10,
+ "max_video_length": 1,
+ "max_audio_length_hours": 8.4,
+ "max_audio_per_prompt": 1,
+ "max_pdf_size_mb": 30,
+ "input_cost_per_token": 0,
+ "output_cost_per_token": 0,
+ "litellm_provider": "vertex_ai-language-models",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_vision": true,
+ "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
+ },
"gemini/gemini-pro": {
"max_tokens": 8192,
"max_input_tokens": 32760,
@@ -1563,36 +1605,36 @@
"mode": "chat"
},
"replicate/meta/llama-3-70b": {
- "max_tokens": 4096,
- "max_input_tokens": 4096,
- "max_output_tokens": 4096,
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000275,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-3-70b-instruct": {
- "max_tokens": 4096,
- "max_input_tokens": 4096,
- "max_output_tokens": 4096,
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000275,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-3-8b": {
- "max_tokens": 4096,
- "max_input_tokens": 4096,
- "max_output_tokens": 4096,
+ "max_tokens": 8086,
+ "max_input_tokens": 8086,
+ "max_output_tokens": 8086,
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-3-8b-instruct": {
- "max_tokens": 4096,
- "max_input_tokens": 4096,
- "max_output_tokens": 4096,
+ "max_tokens": 8086,
+ "max_input_tokens": 8086,
+ "max_output_tokens": 8086,
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
@@ -1856,7 +1898,7 @@
"mode": "chat"
},
"openrouter/meta-llama/codellama-34b-instruct": {
- "max_tokens": 8096,
+ "max_tokens": 8192,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005,
"litellm_provider": "openrouter",
@@ -3348,9 +3390,10 @@
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat",
- "supports_function_calling": true
+ "supports_function_calling": true,
+ "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mistral-7B-Instruct-v0.1"
},
- "anyscale/Mixtral-8x7B-Instruct-v0.1": {
+ "anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1": {
"max_tokens": 16384,
"max_input_tokens": 16384,
"max_output_tokens": 16384,
@@ -3358,7 +3401,19 @@
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat",
- "supports_function_calling": true
+ "supports_function_calling": true,
+ "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x7B-Instruct-v0.1"
+ },
+ "anyscale/mistralai/Mixtral-8x22B-Instruct-v0.1": {
+ "max_tokens": 65536,
+ "max_input_tokens": 65536,
+ "max_output_tokens": 65536,
+ "input_cost_per_token": 0.00000090,
+ "output_cost_per_token": 0.00000090,
+ "litellm_provider": "anyscale",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x22B-Instruct-v0.1"
},
"anyscale/HuggingFaceH4/zephyr-7b-beta": {
"max_tokens": 16384,
@@ -3369,6 +3424,16 @@
"litellm_provider": "anyscale",
"mode": "chat"
},
+ "anyscale/google/gemma-7b-it": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.00000015,
+ "output_cost_per_token": 0.00000015,
+ "litellm_provider": "anyscale",
+ "mode": "chat",
+ "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/google-gemma-7b-it"
+ },
"anyscale/meta-llama/Llama-2-7b-chat-hf": {
"max_tokens": 4096,
"max_input_tokens": 4096,
@@ -3405,6 +3470,36 @@
"litellm_provider": "anyscale",
"mode": "chat"
},
+ "anyscale/codellama/CodeLlama-70b-Instruct-hf": {
+ "max_tokens": 4096,
+ "max_input_tokens": 4096,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "anyscale",
+ "mode": "chat",
+ "source" : "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/codellama-CodeLlama-70b-Instruct-hf"
+ },
+ "anyscale/meta-llama/Meta-Llama-3-8B-Instruct": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.00000015,
+ "output_cost_per_token": 0.00000015,
+ "litellm_provider": "anyscale",
+ "mode": "chat",
+ "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-8B-Instruct"
+ },
+ "anyscale/meta-llama/Meta-Llama-3-70B-Instruct": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.00000100,
+ "output_cost_per_token": 0.00000100,
+ "litellm_provider": "anyscale",
+ "mode": "chat",
+ "source" : "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-70B-Instruct"
+ },
"cloudflare/@cf/meta/llama-2-7b-chat-fp16": {
"max_tokens": 3072,
"max_input_tokens": 3072,
@@ -3496,6 +3591,76 @@
"output_cost_per_token": 0.000000,
"litellm_provider": "voyage",
"mode": "embedding"
- }
+ },
+ "databricks/databricks-dbrx-instruct": {
+ "max_tokens": 32768,
+ "max_input_tokens": 32768,
+ "max_output_tokens": 32768,
+ "input_cost_per_token": 0.00000075,
+ "output_cost_per_token": 0.00000225,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-meta-llama-3-70b-instruct": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000003,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-llama-2-70b-chat": {
+ "max_tokens": 4096,
+ "max_input_tokens": 4096,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000005,
+ "output_cost_per_token": 0.0000015,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-mixtral-8x7b-instruct": {
+ "max_tokens": 4096,
+ "max_input_tokens": 4096,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000005,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-mpt-30b-instruct": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.000001,
+ "output_cost_per_token": 0.000001,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-mpt-7b-instruct": {
+ "max_tokens": 8192,
+ "max_input_tokens": 8192,
+ "max_output_tokens": 8192,
+ "input_cost_per_token": 0.0000005,
+ "output_cost_per_token": 0.0000005,
+ "litellm_provider": "databricks",
+ "mode": "chat",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ },
+ "databricks/databricks-bge-large-en": {
+ "max_tokens": 512,
+ "max_input_tokens": 512,
+ "output_vector_size": 1024,
+ "input_cost_per_token": 0.0000001,
+ "output_cost_per_token": 0.0,
+ "litellm_provider": "databricks",
+ "mode": "embedding",
+ "source": "https://www.databricks.com/product/pricing/foundation-model-serving"
+ }
}
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
index b0b75d094..a53f906ff 100644
--- a/litellm/proxy/_experimental/out/404.html
+++ b/litellm/proxy/_experimental/out/404.html
@@ -1 +1 @@
-404: This page could not be found.LiteLLM Dashboard
404
This page could not be found.
\ No newline at end of file
+404: This page could not be found.LiteLLM Dashboard