Merge branch 'main' into litellm_security_fix

This commit is contained in:
Ishaan Jaff 2024-06-07 16:52:25 -07:00 committed by GitHub
commit 92841dfe1b
31 changed files with 2394 additions and 5332 deletions

1
.gitignore vendored
View file

@ -59,3 +59,4 @@ myenv/*
litellm/proxy/_experimental/out/404/index.html
litellm/proxy/_experimental/out/model_hub/index.html
litellm/proxy/_experimental/out/onboarding/index.html
litellm/tests/log.txt

View file

@ -62,6 +62,23 @@ curl -X GET 'http://localhost:4000/health/services?service=slack' \
-H 'Authorization: Bearer sk-1234'
```
## Advanced - Redacting Messages from Alerts
By default alerts show the `messages/input` passed to the LLM. If you want to redact this from slack alerting set the following setting on your config
```shell
general_settings:
alerting: ["slack"]
alert_types: ["spend_reports"]
litellm_settings:
redact_messages_in_exceptions: True
```
## Advanced - Opting into specific alert types
Set `alert_types` if you want to Opt into only specific alert types

View file

@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
### INIT VARIABLES ###
import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.caching import Cache
from litellm._logging import (
set_verbose,
@ -60,6 +60,7 @@ _async_failure_callback: List[Callable] = (
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
redact_messages_in_exceptions: Optional[bool] = False
store_audit_logs = False # Enterprise feature, allow users to see audit logs
## end of callbacks #############
@ -233,6 +234,7 @@ max_end_user_budget: Optional[float] = None
#### RELIABILITY ####
request_timeout: float = 6000
module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
module_level_client = HTTPHandler(timeout=request_timeout)
num_retries: Optional[int] = None # per model endpoint
default_fallbacks: Optional[List] = None
fallbacks: Optional[List] = None
@ -766,7 +768,7 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
from .llms.bedrock import (
AmazonTitanConfig,
AmazonAI21Config,

View file

@ -1,10 +1,18 @@
import litellm, traceback
from datetime import datetime
import litellm
from litellm.proxy._types import UserAPIKeyAuth
from .types.services import ServiceTypes, ServiceLoggerPayload
from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.custom_logger import CustomLogger
from datetime import timedelta
from typing import Union
from typing import Union, Optional, TYPE_CHECKING, Any
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class ServiceLogging(CustomLogger):
@ -40,7 +48,13 @@ class ServiceLogging(CustomLogger):
self.mock_testing_sync_failure_hook += 1
async def async_service_success_hook(
self, service: ServiceTypes, duration: float, call_type: str
self,
service: ServiceTypes,
call_type: str,
duration: float,
parent_otel_span: Optional[Span] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
):
"""
- For counting if the redis, postgres call is successful
@ -61,6 +75,16 @@ class ServiceLogging(CustomLogger):
payload=payload
)
from litellm.proxy.proxy_server import open_telemetry_logger
if parent_otel_span is not None and open_telemetry_logger is not None:
await open_telemetry_logger.async_service_success_hook(
payload=payload,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
)
async def async_service_failure_hook(
self,
service: ServiceTypes,

View file

@ -1,9 +1,21 @@
import os
from typing import Optional
from dataclasses import dataclass
from datetime import datetime
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_logger
from litellm.types.services import ServiceLoggerPayload
from typing import Union, Optional, TYPE_CHECKING, Any
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
Span = _Span
UserAPIKeyAuth = _UserAPIKeyAuth
else:
Span = Any
UserAPIKeyAuth = Any
LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm")
@ -77,6 +89,56 @@ class OpenTelemetry(CustomLogger):
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
self._handle_failure(kwargs, response_obj, start_time, end_time)
async def async_service_success_hook(
self,
payload: ServiceLoggerPayload,
parent_otel_span: Optional[Span] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
):
from opentelemetry import trace
from datetime import datetime
from opentelemetry.trace import Status, StatusCode
if parent_otel_span is not None:
_span_name = payload.service
service_logging_span = self.tracer.start_span(
name=_span_name,
context=trace.set_span_in_context(parent_otel_span),
start_time=self._to_ns(start_time),
)
service_logging_span.set_attribute(key="call_type", value=payload.call_type)
service_logging_span.set_attribute(
key="service", value=payload.service.value
)
service_logging_span.set_status(Status(StatusCode.OK))
service_logging_span.end(end_time=self._to_ns(end_time))
async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
from opentelemetry.trace import Status, StatusCode
from opentelemetry import trace
parent_otel_span = user_api_key_dict.parent_otel_span
if parent_otel_span is not None:
parent_otel_span.set_status(Status(StatusCode.ERROR))
_span_name = "Failed Proxy Server Request"
# Exception Logging Child Span
exception_logging_span = self.tracer.start_span(
name=_span_name,
context=trace.set_span_in_context(parent_otel_span),
)
exception_logging_span.set_attribute(
key="exception", value=str(original_exception)
)
exception_logging_span.set_status(Status(StatusCode.ERROR))
exception_logging_span.end(end_time=self._to_ns(datetime.now()))
# End Parent OTEL Sspan
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
from opentelemetry.trace import Status, StatusCode
@ -85,15 +147,18 @@ class OpenTelemetry(CustomLogger):
kwargs,
self.config,
)
_parent_context, parent_otel_span = self._get_span_context(kwargs)
span = self.tracer.start_span(
name=self._get_span_name(kwargs),
start_time=self._to_ns(start_time),
context=self._get_span_context(kwargs),
context=_parent_context,
)
span.set_status(Status(StatusCode.OK))
self.set_attributes(span, kwargs, response_obj)
span.end(end_time=self._to_ns(end_time))
if parent_otel_span is not None:
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
from opentelemetry.trace import Status, StatusCode
@ -122,17 +187,28 @@ class OpenTelemetry(CustomLogger):
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
from opentelemetry import trace
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
headers = proxy_server_request.get("headers", {}) or {}
traceparent = headers.get("traceparent", None)
_metadata = litellm_params.get("metadata", {})
parent_otel_span = _metadata.get("litellm_parent_otel_span", None)
"""
Two way to use parents in opentelemetry
- using the traceparent header
- using the parent_otel_span in the [metadata][parent_otel_span]
"""
if parent_otel_span is not None:
return trace.set_span_in_context(parent_otel_span), parent_otel_span
if traceparent is None:
return None
return None, None
else:
carrier = {"traceparent": traceparent}
return TraceContextTextMapPropagator().extract(carrier=carrier)
return TraceContextTextMapPropagator().extract(carrier=carrier), None
def _get_span_processor(self):
from opentelemetry.sdk.trace.export import (

View file

@ -326,8 +326,8 @@ class SlackAlerting(CustomLogger):
end_time=end_time,
)
)
if litellm.turn_off_message_logging:
messages = "Message not logged. `litellm.turn_off_message_logging=True`."
if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold:
@ -567,9 +567,12 @@ class SlackAlerting(CustomLogger):
except:
messages = ""
if litellm.turn_off_message_logging:
if (
litellm.turn_off_message_logging
or litellm.redact_messages_in_exceptions
):
messages = (
"Message not logged. `litellm.turn_off_message_logging=True`."
"Message not logged. litellm.redact_messages_in_exceptions=True"
)
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
else:

View file

@ -38,6 +38,8 @@ from .prompt_templates.factory import (
extract_between_tags,
parse_xml_params,
contains_tag,
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
@ -45,6 +47,11 @@ import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
from litellm.types.llms.bedrock import *
import urllib.parse
from litellm.types.llms.openai import (
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
)
class AmazonCohereChatConfig:
@ -118,6 +125,8 @@ class AmazonCohereChatConfig:
"presence_penalty",
"seed",
"stop",
"tools",
"tool_choice",
]
def map_openai_params(
@ -169,7 +178,38 @@ async def make_call(
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream
def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.read())
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
@ -1000,12 +1040,12 @@ class BedrockLLM(BaseLLM):
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
client = client # type: ignore
try:
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
response = await client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
@ -1069,6 +1109,738 @@ class BedrockLLM(BaseLLM):
return super().embedding(*args, **kwargs)
class AmazonConverseConfig:
"""
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
"""
maxTokens: Optional[int]
stopSequences: Optional[List[str]]
temperature: Optional[int]
topP: Optional[int]
def __init__(
self,
maxTokens: Optional[int] = None,
stopSequences: Optional[List[str]] = None,
temperature: Optional[int] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> List[str]:
supported_params = [
"max_tokens",
"stream",
"stream_options",
"stop",
"temperature",
"top_p",
"extra_headers",
]
if (
model.startswith("anthropic")
or model.startswith("mistral")
or model.startswith("cohere")
):
supported_params.append("tools")
if model.startswith("anthropic") or model.startswith("mistral"):
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
supported_params.append("tool_choice")
return supported_params
def map_tool_choice_values(
self, model: str, tool_choice: Union[str, dict], drop_params: bool
) -> Optional[ToolChoiceValuesBlock]:
if tool_choice == "none":
if litellm.drop_params is True or drop_params is True:
return None
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
elif tool_choice == "required":
return ToolChoiceValuesBlock(any={})
elif tool_choice == "auto":
return ToolChoiceValuesBlock(auto={})
elif isinstance(tool_choice, dict):
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
specific_tool = SpecificToolChoiceBlock(
name=tool_choice.get("function", {}).get("name", "")
)
return ToolChoiceValuesBlock(tool=specific_tool)
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
def get_supported_image_types(self) -> List[str]:
return ["png", "jpeg", "gif", "webp"]
def map_openai_params(
self,
model: str,
non_default_params: dict,
optional_params: dict,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["maxTokens"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
value = [value]
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["topP"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice_value = self.map_tool_choice_values(
model=model, tool_choice=value, drop_params=drop_params # type: ignore
)
if _tool_choice_value is not None:
optional_params["tool_choice"] = _tool_choice_value
return optional_params
class BedrockConverseLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> Union[ModelResponse, CustomStreamWrapper]:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
except Exception as e:
raise BedrockError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
response.text, str(e)
),
status_code=422,
)
"""
Bedrock Response Object has optional message block
completion_response["output"].get("message", None)
A message block looks like this (Example 1):
"output": {
"message": {
"role": "assistant",
"content": [
{
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
}
]
}
},
(Example 2):
"output": {
"message": {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
"name": "top_song",
"input": {
"sign": "WZPZ"
}
}
}
]
}
}
"""
message: Optional[MessageBlock] = completion_response["output"]["message"]
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
if message is not None:
for content in message["content"]:
"""
- Content is either a tool response or text
"""
if "text" in content:
content_str += content["text"]
if "toolUse" in content:
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=content["toolUse"]["name"],
arguments=json.dumps(content["toolUse"]["input"]),
)
_tool_response_chunk = ChatCompletionToolCallChunk(
id=content["toolUse"]["toolUseId"],
type="function",
function=_function_chunk,
)
tools.append(_tool_response_chunk)
chat_completion_message["content"] = content_str
chat_completion_message["tool_calls"] = tools
## CALCULATING USAGE - bedrock returns usage in the headers
input_tokens = completion_response["usage"]["inputTokens"]
output_tokens = completion_response["usage"]["outputTokens"]
total_tokens = completion_response["usage"]["totalTokens"]
model_response.choices = [
litellm.Choices(
finish_reason=map_finish_reason(completion_response["stopReason"]),
index=0,
message=litellm.Message(**chat_completion_message),
)
]
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=total_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def encode_model_id(self, model_id: str) -> str:
"""
Double encode the model ID to ensure it matches the expected double-encoded format.
Args:
model_id (str): The model ID to encode.
Returns:
str: The double-encoded model ID.
"""
return urllib.parse.quote(model_id, safe="")
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IS 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
return session.get_credentials()
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
from botocore.credentials import Credentials
credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"],
token=sts_credentials["SessionToken"],
)
return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
client=client,
api_base=api_base,
headers=headers,
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore
else:
client = client # type: ignore
try:
response = await client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream if isinstance(stream, bool) else False,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def completion(
self,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
):
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ##
stream = optional_params.pop("stream", None)
modelId = optional_params.pop("model_id", None)
if modelId is not None:
modelId = self.encode_model_id(model_id=modelId)
else:
modelId = model
provider = model.split(".")[0]
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
)
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
else:
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[SystemContentBlock] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = SystemContentBlock(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
inference_params = copy.deepcopy(optional_params)
additional_request_keys = []
additional_request_params = {}
supported_converse_params = AmazonConverseConfig.__annotations__.keys()
supported_tool_call_params = ["tools", "tool_choice"]
## TRANSFORMATION ##
# send all model-specific params in 'additional_request_params'
for k, v in inference_params.items():
if (
k not in supported_converse_params
and k not in supported_tool_call_params
):
additional_request_params[k] = v
additional_request_keys.append(k)
for key in additional_request_keys:
inference_params.pop(key, None)
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
)
bedrock_tool_config: Optional[ToolConfigBlock] = None
if len(bedrock_tools) > 0:
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
"tool_choice", None
)
bedrock_tool_config = ToolConfigBlock(
tools=bedrock_tools,
)
if tool_choice_values is not None:
bedrock_tool_config["toolChoice"] = tool_choice_values
_data: RequestObject = {
"messages": bedrock_messages,
"additionalModelRequestFields": additional_request_params,
"system": system_content_blocks,
"inferenceConfig": InferenceConfig(**inference_params),
}
if bedrock_tool_config is not None:
_data["toolConfig"] = bedrock_tool_config
data = json.dumps(_data)
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream is True and provider != "ai21":
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if (stream is not None and stream is True) and provider != "ai21":
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=None,
api_base=prepped.url,
headers=prepped.headers, # type: ignore
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
### COMPLETION
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = HTTPHandler(**_params) # type: ignore
else:
client = client
try:
response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
def get_response_stream_shape():
from botocore.model import ServiceModel
from botocore.loaders import Loader
@ -1086,6 +1858,31 @@ class AWSEventStreamDecoder:
self.model = model
self.parser = EventStreamJSONParser()
def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
tool_str = ""
is_finished = False
finish_reason = ""
usage: Optional[ConverseTokenUsageBlock] = None
if "delta" in chunk_data:
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
if "text" in delta_obj:
text = delta_obj["text"]
elif "toolUse" in delta_obj:
tool_str = delta_obj["toolUse"]["input"]
elif "stopReason" in chunk_data:
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
elif "usage" in chunk_data:
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
response = GenericStreamingChunk(
text=text,
tool_str=tool_str,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
)
return response
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
is_finished = False
@ -1098,19 +1895,8 @@ class AWSEventStreamDecoder:
is_finished = True
finish_reason = "stop"
######## bedrock.anthropic mappings ###############
elif "completion" in chunk_data: # not claude-3
text = chunk_data["completion"] # bedrock.anthropic
stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
elif "delta" in chunk_data:
if chunk_data["delta"].get("text", None) is not None:
text = chunk_data["delta"]["text"]
stop_reason = chunk_data["delta"].get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
return self.converse_chunk_parser(chunk_data=chunk_data)
######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data:
if (
@ -1137,11 +1923,11 @@ class AWSEventStreamDecoder:
is_finished = True
finish_reason = chunk_data["completionReason"]
return GenericStreamingChunk(
**{
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
text=text,
is_finished=is_finished,
finish_reason=finish_reason,
tool_str="",
usage=None,
)
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
@ -1178,9 +1964,14 @@ class AWSEventStreamDecoder:
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
if "chunk" in parsed_response:
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
else:
chunk = response_dict.get("body")
if not chunk:
return None
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
return chunk.decode() # type: ignore[no-any-return]

View file

@ -156,12 +156,13 @@ class HTTPHandler:
self,
url: str,
data: Optional[Union[dict, str]] = None,
json: Optional[Union[dict, str]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
)
response = self.client.send(req, stream=stream)
return response

View file

@ -3,14 +3,7 @@ import requests, traceback
import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, meta, BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment
from typing import (
Any,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
)
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
import litellm
import litellm.types
from litellm.types.completion import (
@ -24,7 +17,7 @@ from litellm.types.completion import (
import litellm.types.llms
from litellm.types.llms.anthropic import *
import uuid
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
import litellm.types.llms.vertex_ai
@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
try:
from PIL import Image
except:
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
raise Exception("image conversion failed please run `pip install Pillow`")
from io import BytesIO
try:
@ -1613,6 +1604,380 @@ def azure_text_pt(messages: list):
return prompt
###### AMAZON BEDROCK #######
from litellm.types.llms.bedrock import (
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultBlock as BedrockToolResultBlock,
ToolConfigBlock as BedrockToolConfigBlock,
ToolUseBlock as BedrockToolUseBlock,
ImageSourceBlock as BedrockImageSourceBlock,
ImageBlock as BedrockImageBlock,
ContentBlock as BedrockContentBlock,
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
ToolSpecBlock as BedrockToolSpecBlock,
ToolBlock as BedrockToolBlock,
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
)
def get_image_details(image_url) -> Tuple[str, str]:
try:
import base64
# Send a GET request to the image URL
response = requests.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image
content_type = response.headers.get("content-type")
if not content_type or "image" not in content_type:
raise ValueError(
f"URL does not point to a valid image (content-type: {content_type})"
)
# Convert the image content to base64 bytes
base64_bytes = base64.b64encode(response.content).decode("utf-8")
# Get mime-type
mime_type = content_type.split("/")[
1
] # Extract mime-type from content-type header
return base64_bytes, mime_type
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e:
raise e
def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
if "base64" in image_url:
# Case 1: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = image_url.split(",")
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
mime_type = mime_type_match.group(1)
image_format = mime_type.split("/")[1]
else:
mime_type = "image/jpeg"
image_format = "jpeg"
_blob = BedrockImageSourceBlock(bytes=img_without_base_64)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
elif "https:/" in image_url:
# Case 2: Images with direct links
image_bytes, image_format = get_image_details(image_url)
_blob = BedrockImageSourceBlock(bytes=image_bytes)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
else:
raise ValueError(
"Unsupported image type. Expected either image url or base64 encoded string - \
e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
)
def _convert_to_bedrock_tool_call_invoke(
tool_calls: list,
) -> List[BedrockContentBlock]:
"""
OpenAI tool invokes:
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
}
}
]
},
"""
"""
Bedrock tool invokes:
[
{
"role": "assistant",
"toolUse": {
"input": {"location": "Boston, MA", ..},
"name": "get_current_weather",
"toolUseId": "call_abc123"
}
}
]
"""
"""
- json.loads argument
- extract name
- extract id
"""
try:
_parts_list: List[BedrockContentBlock] = []
for tool in tool_calls:
if "function" in tool:
id = tool["id"]
name = tool["function"].get("name", "")
arguments = tool["function"].get("arguments", "")
arguments_dict = json.loads(arguments)
bedrock_tool = BedrockToolUseBlock(
input=arguments_dict, name=name, toolUseId=id
)
bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
_parts_list.append(bedrock_content_block)
return _parts_list
except Exception as e:
raise Exception(
"Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
tool_calls, str(e)
)
)
def _convert_to_bedrock_tool_call_result(
message: dict,
) -> BedrockMessageBlock:
"""
OpenAI message with a tool result looks like:
{
"tool_call_id": "tool_1",
"role": "tool",
"name": "get_current_weather",
"content": "function result goes here",
},
OpenAI message with a function call result looks like:
{
"role": "function",
"name": "get_current_weather",
"content": "function result goes here",
}
"""
"""
Bedrock result looks like this:
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
"content": [
{
"json": {
"song": "Elemental Hotel",
"artist": "8 Storey Hike"
}
}
]
}
}
]
}
"""
"""
-
"""
content = message.get("content", "")
name = message.get("name", "")
id = message.get("tool_call_id", str(uuid.uuid4()))
tool_result_content_block = BedrockToolResultContentBlock(text=content)
tool_result = BedrockToolResultBlock(
content=[tool_result_content_block],
toolUseId=id,
)
content_block = BedrockContentBlock(toolResult=tool_result)
return BedrockMessageBlock(role="user", content=[content_block])
def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
"""
Converts given messages from OpenAI format to Bedrock format
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
- Please ensure that function response turn comes immediately after a function call turn
"""
contents: List[BedrockMessageBlock] = []
msg_i = 0
while msg_i < len(messages):
user_content: List[BedrockContentBlock] = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "user":
if isinstance(messages[msg_i]["content"], list):
_parts: List[BedrockContentBlock] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
_parts.append(BedrockContentBlock(image=_part)) # type: ignore
user_content.extend(_parts)
else:
_part = BedrockContentBlock(text=messages[msg_i]["content"])
user_content.append(_part)
msg_i += 1
if user_content:
contents.append(BedrockMessageBlock(role="user", content=user_content))
assistant_content: List[BedrockContentBlock] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
if isinstance(messages[msg_i]["content"], list):
assistants_parts: List[BedrockContentBlock] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
assistants_part = BedrockContentBlock(text=element["text"])
assistants_parts.append(assistants_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
assistants_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
assistants_parts.append(
BedrockContentBlock(image=assistants_part) # type: ignore
)
assistant_content.extend(assistants_parts)
elif messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke convertion
assistant_content.extend(
_convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
)
else:
assistant_text = (
messages[msg_i].get("content") or ""
) # either string or none
if assistant_text:
assistant_content.append(BedrockContentBlock(text=assistant_text))
msg_i += 1
if assistant_content:
contents.append(
BedrockMessageBlock(role="assistant", content=assistant_content)
)
## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
contents.append(tool_call_result)
msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
messages[msg_i]
)
)
return contents
def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
"""
OpenAI tools looks like:
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
}
]
"""
"""
Bedrock toolConfig looks like:
"tools": [
{
"toolSpec": {
"name": "top_song",
"description": "Get the most popular song played on a radio station.",
"inputSchema": {
"json": {
"type": "object",
"properties": {
"sign": {
"type": "string",
"description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."
}
},
"required": [
"sign"
]
}
}
}
}
]
"""
tool_block_list: List[BedrockToolBlock] = []
for tool in tools:
parameters = tool.get("function", {}).get("parameters", None)
name = tool.get("function", {}).get("name", "")
description = tool.get("function", {}).get("description", "")
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
tool_spec = BedrockToolSpecBlock(
inputSchema=tool_input_schema, name=name, description=description
)
tool_block = BedrockToolBlock(toolSpec=tool_spec)
tool_block_list.append(tool_block)
return tool_block_list
# Function call template
def function_call_prompt(messages: list, functions: list):
function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""

View file

@ -12,6 +12,7 @@ from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_result,
convert_to_gemini_tool_call_invoke,
)
from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
class VertexAIError(Exception):
@ -297,29 +298,31 @@ def _convert_gemini_role(role: str) -> Literal["user", "model"]:
def _process_gemini_image(image_url: str) -> PartType:
try:
if ".mp4" in image_url and "gs://" in image_url:
# Case 1: Videos with Cloud Storage URIs
part_mime = "video/mp4"
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
return PartType(file_data=_file_data)
elif ".pdf" in image_url and "gs://" in image_url:
# Case 2: PDF's with Cloud Storage URIs
part_mime = "application/pdf"
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
return PartType(file_data=_file_data)
elif "gs://" in image_url:
# Case 3: Images with Cloud Storage URIs
# The supported MIME types for images include image/png and image/jpeg.
part_mime = "image/png" if "png" in image_url else "image/jpeg"
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
return PartType(file_data=_file_data)
# GCS URIs
if "gs://" in image_url:
# Figure out file type
extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
extension = extension_with_dot[1:] # Ex: "png"
file_type = get_file_type_from_extension(extension)
# Validate the file type is supported by Gemini
if not is_gemini_1_5_accepted_file_type(file_type):
raise Exception(f"File type not supported by gemini - {file_type}")
mime_type = get_file_mime_type_for_file_type(file_type)
file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
return PartType(file_data=file_data)
# Direct links
elif "https:/" in image_url:
# Case 4: Images with direct links
image = _load_image_from_url(image_url)
_blob = BlobType(data=image.data, mime_type=image._mime_type)
return PartType(inline_data=_blob)
# Base64 encoding
elif "base64" in image_url:
# Case 5: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
@ -426,112 +429,6 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
return contents
def _gemini_vision_convert_messages(messages: list):
"""
Converts given messages for GPT-4 Vision to Gemini format.
Args:
messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
- If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
- If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.
Returns:
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
Raises:
VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
Exception: If any other exception occurs during the execution of the function.
Note:
This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub.
The supported MIME types for images include 'image/png' and 'image/jpeg'.
Examples:
>>> messages = [
... {"content": "Hello, world!"},
... {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": "example.com/image.png"}]},
... ]
>>> _gemini_vision_convert_messages(messages)
('Hello, world!This is a text message.', [<Part object>, <Part object>])
"""
try:
import vertexai
except:
raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
try:
from vertexai.preview.language_models import (
ChatModel,
CodeChatModel,
InputOutputTextPair,
)
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
# given messages for gpt-4 vision, convert them for gemini
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
prompt = ""
images = []
for message in messages:
if isinstance(message["content"], str):
prompt += message["content"]
elif isinstance(message["content"], list):
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
for element in message["content"]:
if isinstance(element, dict):
if element["type"] == "text":
prompt += element["text"]
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
images.append(image_url)
# processing images passed to gemini
processed_images = []
for img in images:
if "gs://" in img:
# Case 1: Images with Cloud Storage URIs
# The supported MIME types for images include image/png and image/jpeg.
part_mime = "image/png" if "png" in img else "image/jpeg"
google_clooud_part = Part.from_uri(img, mime_type=part_mime)
processed_images.append(google_clooud_part)
elif "https:/" in img:
# Case 2: Images with direct links
image = _load_image_from_url(img)
processed_images.append(image)
elif ".mp4" in img and "gs://" in img:
# Case 3: Videos with Cloud Storage URIs
part_mime = "video/mp4"
google_clooud_part = Part.from_uri(img, mime_type=part_mime)
processed_images.append(google_clooud_part)
elif "base64" in img:
# Case 4: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = img.split(",")
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
mime_type = mime_type_match.group(1)
else:
mime_type = "image/jpeg"
decoded_img = base64.b64decode(img_without_base_64)
processed_image = Part.from_data(data=decoded_img, mime_type=mime_type)
processed_images.append(processed_image)
return prompt, processed_images
except Exception as e:
raise e
def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
_cache_key = f"{model}-{vertex_project}-{vertex_location}"
return _cache_key

View file

@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
@ -122,6 +122,7 @@ huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################
@ -2107,22 +2108,40 @@ def completion(
logging_obj=logging,
)
else:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
client=client,
)
if model.startswith("anthropic"):
response = bedrock_converse_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
client=client,
)
else:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
client=client,
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(

View file

@ -1,11 +1,20 @@
from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
from dataclasses import fields
import enum
from typing import Optional, List, Union, Dict, Literal, Any, TypedDict
from typing import Optional, List, Union, Dict, Literal, Any, TypedDict, TYPE_CHECKING
from datetime import datetime
import uuid, json, sys, os
from litellm.types.router import UpdateRouterConfig
from litellm.types.utils import ProviderField
from typing_extensions import Annotated
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class LitellmUserRoles(str, enum.Enum):
@ -1195,6 +1204,7 @@ class UserAPIKeyAuth(
]
] = None
allowed_model_region: Optional[Literal["eu"]] = None
parent_otel_span: Optional[Span] = None
@model_validator(mode="before")
@classmethod
@ -1207,6 +1217,9 @@ class UserAPIKeyAuth(
values.update({"api_key": hash_token(values.get("api_key"))})
return values
class Config:
arbitrary_types_allowed = True
class LiteLLM_Config(LiteLLMBase):
param_name: str

View file

@ -17,10 +17,19 @@ from litellm.proxy._types import (
LiteLLM_OrganizationTable,
LitellmUserRoles,
)
from typing import Optional, Literal, Union
from litellm.proxy.utils import PrismaClient
from typing import Optional, Literal, TYPE_CHECKING, Any
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
from litellm.caching import DualCache
import litellm
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from datetime import datetime
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
@ -216,10 +225,13 @@ def get_actual_routes(allowed_routes: list) -> list:
return actual_routes
@log_to_opentelemetry
async def get_end_user_object(
end_user_id: Optional[str],
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_EndUserTable]:
"""
Returns end user object, if in db.
@ -279,11 +291,14 @@ async def get_end_user_object(
return None
@log_to_opentelemetry
async def get_user_object(
user_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
user_id_upsert: bool,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_UserTable]:
"""
- Check if user id in proxy User Table
@ -330,10 +345,13 @@ async def get_user_object(
)
@log_to_opentelemetry
async def get_team_object(
team_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> LiteLLM_TeamTable:
"""
- Check if team id in proxy Team Table
@ -372,10 +390,13 @@ async def get_team_object(
)
@log_to_opentelemetry
async def get_org_object(
org_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
):
"""
- Check if org id in proxy Org Table

View file

@ -0,0 +1,130 @@
import copy
from fastapi import Request
from typing import Any, Dict, Optional, TYPE_CHECKING
from litellm.proxy._types import UserAPIKeyAuth
from litellm._logging import verbose_proxy_logger, verbose_logger
if TYPE_CHECKING:
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
ProxyConfig = _ProxyConfig
else:
ProxyConfig = Any
def parse_cache_control(cache_control):
cache_dict = {}
directives = cache_control.split(", ")
for directive in directives:
if "=" in directive:
key, value = directive.split("=")
cache_dict[key] = value
else:
cache_dict[directive] = True
return cache_dict
async def add_litellm_data_to_request(
data: dict,
request: Request,
user_api_key_dict: UserAPIKeyAuth,
proxy_config: ProxyConfig,
general_settings: Optional[Dict[str, Any]] = None,
version: Optional[str] = None,
):
"""
Adds LiteLLM-specific data to the request.
Args:
data (dict): The data dictionary to be modified.
request (Request): The incoming request.
user_api_key_dict (UserAPIKeyAuth): The user API key dictionary.
general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None.
version (Optional[str], optional): Version. Defaults to None.
Returns:
dict: The modified data dictionary.
"""
query_params = dict(request.query_params)
if "api-version" in query_params:
data["api_version"] = query_params["api-version"]
# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
## Cache Controls
headers = request.headers
verbose_proxy_logger.debug("Request Headers: %s", headers)
cache_control_header = headers.get("Cache-Control", None)
if cache_control_header:
cache_dict = parse_cache_control(cache_control_header)
data["ttl"] = cache_dict.get("s-maxage")
verbose_proxy_logger.debug("receiving data: %s", data)
# users can pass in 'user' param to /chat/completions. Don't override it
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
# if users are using user_api_key_auth, set `user` in `data`
data["user"] = user_api_key_dict.user_id
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["metadata"]["user_api_end_user_max_budget"] = getattr(
user_api_key_dict, "end_user_max_budget", None
)
data["metadata"]["litellm_api_version"] = version
if general_settings is not None:
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["endpoint"] = str(request.url)
# Add the OTEL Parent Trace before sending it LiteLLM
data["metadata"]["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span
### END-USER SPECIFIC PARAMS ###
if user_api_key_dict.allowed_model_region is not None:
data["allowed_model_region"] = user_api_key_dict.allowed_model_region
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
return data

View file

@ -21,10 +21,14 @@ model_list:
general_settings:
master_key: sk-1234
alerting: ["slack"]
litellm_settings:
callbacks: ["otel"]
store_audit_logs: true
redact_messages_in_exceptions: True
enforced_params:
- user
- metadata
- metadata.generation_name
litellm_settings:
store_audit_logs: true

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
from typing import Optional, List, Any, Literal, Union
from typing import Optional, List, Any, Literal, Union, TYPE_CHECKING
import os
import subprocess
import hashlib
@ -46,6 +46,15 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime, timedelta
from litellm.integrations.slack_alerting import SlackAlerting
from typing_extensions import overload
from functools import wraps
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
def print_verbose(print_statement):
@ -63,6 +72,58 @@ def print_verbose(print_statement):
print(f"LiteLLM Proxy: {print_statement}") # noqa
def safe_deep_copy(data):
"""
Safe Deep Copy
The LiteLLM Request has some object that can-not be pickled / deep copied
Use this function to safely deep copy the LiteLLM Request
"""
# Step 1: Remove the litellm_parent_otel_span
if isinstance(data, dict):
# remove litellm_parent_otel_span since this is not picklable
if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span")
new_data = copy.deepcopy(data)
# Step 2: re-add the litellm_parent_otel_span after doing a deep copy
if isinstance(data, dict):
if "metadata" in data:
data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
return new_data
def log_to_opentelemetry(func):
@wraps(func)
async def wrapper(*args, **kwargs):
start_time = datetime.now()
result = await func(*args, **kwargs)
end_time = datetime.now()
# Log to OTEL only if "parent_otel_span" is in kwargs and is not None
if (
"parent_otel_span" in kwargs
and kwargs["parent_otel_span"] is not None
and "proxy_logging_obj" in kwargs
and kwargs["proxy_logging_obj"] is not None
):
proxy_logging_obj = kwargs["proxy_logging_obj"]
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.DB,
call_type=func.__name__,
parent_otel_span=kwargs["parent_otel_span"],
duration=0.0,
start_time=start_time,
end_time=end_time,
)
# end of logging to otel
return result
return wrapper
### LOGGING ###
class ProxyLogging:
"""
@ -282,7 +343,7 @@ class ProxyLogging:
"""
Runs the CustomLogger's async_moderation_hook()
"""
new_data = copy.deepcopy(data)
new_data = safe_deep_copy(data)
for callback in litellm.callbacks:
try:
if isinstance(callback, CustomLogger):
@ -832,6 +893,7 @@ class PrismaClient:
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
@log_to_opentelemetry
async def get_data(
self,
token: Optional[Union[str, list]] = None,
@ -858,6 +920,8 @@ class PrismaClient:
limit: Optional[
int
] = None, # pagination, number of rows to getch when find_all==True
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
):
args_passed_in = locals()
start_time = time.time()
@ -2829,6 +2893,10 @@ missing_keys_html_form = """
"""
def _to_ns(dt):
return int(dt.timestamp() * 1e9)
def get_error_message_str(e: Exception) -> str:
error_message = ""
if isinstance(e, HTTPException):

2
litellm/py.typed Normal file
View file

@ -0,0 +1,2 @@
# Marker file to instruct type checkers to look for inline type annotations in this package.
# See PEP 561 for more information.

File diff suppressed because it is too large Load diff

View file

@ -243,6 +243,7 @@ def test_completion_bedrock_claude_sts_oidc_auth():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
reason="Cannot run without being in CircleCI Runner",
@ -277,7 +278,15 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3():
@pytest.mark.parametrize(
"image_url",
[
"",
"https://avatars.githubusercontent.com/u/29436595?v=",
],
)
def test_bedrock_claude_3(image_url):
try:
litellm.set_verbose = True
data = {
@ -294,7 +303,7 @@ def test_bedrock_claude_3():
{
"image_url": {
"detail": "high",
"url": "",
"url": image_url,
},
"type": "image_url",
},
@ -313,7 +322,6 @@ def test_bedrock_claude_3():
# Add any assertions here to check the response
assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
@ -552,7 +560,7 @@ def test_bedrock_ptu():
assert "url" in mock_client_post.call_args.kwargs
assert (
mock_client_post.call_args.kwargs["url"]
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/converse"
)
mock_client_post.assert_called_once()

View file

@ -300,7 +300,11 @@ def test_completion_claude_3():
pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3_function_call():
@pytest.mark.parametrize(
"model",
["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
)
def test_completion_claude_3_function_call(model):
litellm.set_verbose = True
tools = [
{
@ -331,13 +335,14 @@ def test_completion_claude_3_function_call():
try:
# test without max tokens
response = completion(
model="anthropic/claude-3-opus-20240229",
model=model,
messages=messages,
tools=tools,
tool_choice={
"type": "function",
"function": {"name": "get_current_weather"},
},
drop_params=True,
)
# Add any assertions, here to check response args
@ -364,10 +369,11 @@ def test_completion_claude_3_function_call():
)
# In the second response, Claude should deduce answer from tool results
second_response = completion(
model="anthropic/claude-3-opus-20240229",
model=model,
messages=messages,
tools=tools,
tool_choice="auto",
drop_params=True,
)
print(second_response)
except Exception as e:
@ -1398,7 +1404,6 @@ def test_hf_test_completion_tgi():
def mock_post(url, data=None, json=None, headers=None):
print(f"url={url}")
if "text-classification" in url:
raise Exception("Model not found")
@ -2241,9 +2246,6 @@ def test_re_use_openaiClient():
pytest.fail("got Exception", e)
# test_re_use_openaiClient()
def test_completion_azure():
try:
print("azure gpt-3.5 test\n\n")

View file

@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import (
claude_2_1_pt,
llama_2_chat_pt,
prompt_factory,
_bedrock_tools_pt,
)
@ -128,3 +129,27 @@ def test_anthropic_messages_pt():
# codellama_prompt_format()
def test_bedrock_tool_calling_pt():
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
converted_tools = _bedrock_tools_pt(tools=tools)
print(converted_tools)

View file

@ -210,7 +210,9 @@ def test_chat_completion_exception_any_model(client):
)
assert isinstance(openai_exception, openai.BadRequestError)
_error_message = openai_exception.message
assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str(_error_message)
assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str(
_error_message
)
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
@ -238,7 +240,9 @@ def test_embedding_exception_any_model(client):
print("Exception raised=", openai_exception)
assert isinstance(openai_exception, openai.BadRequestError)
_error_message = openai_exception.message
assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str(_error_message)
assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str(
_error_message
)
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

View file

@ -1284,18 +1284,18 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
# pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("sync_mode", [True]) # False
@pytest.mark.parametrize(
"model",
[
# "bedrock/cohere.command-r-plus-v1:0",
# "anthropic.claude-3-sonnet-20240229-v1:0",
# "anthropic.claude-instant-v1",
# "bedrock/ai21.j2-mid",
# "mistral.mistral-7b-instruct-v0:2",
# "bedrock/amazon.titan-tg1-large",
# "meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14"
"bedrock/cohere.command-r-plus-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-instant-v1",
"bedrock/ai21.j2-mid",
"mistral.mistral-7b-instruct-v0:2",
"bedrock/amazon.titan-tg1-large",
"meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14",
],
)
@pytest.mark.asyncio

281
litellm/types/files.py Normal file
View file

@ -0,0 +1,281 @@
from enum import Enum
from types import MappingProxyType
from typing import List, Set
"""
Base Enums/Consts
"""
class FileType(Enum):
AAC = "AAC"
CSV = "CSV"
DOC = "DOC"
DOCX = "DOCX"
FLAC = "FLAC"
FLV = "FLV"
GIF = "GIF"
GOOGLE_DOC = "GOOGLE_DOC"
GOOGLE_DRAWINGS = "GOOGLE_DRAWINGS"
GOOGLE_SHEETS = "GOOGLE_SHEETS"
GOOGLE_SLIDES = "GOOGLE_SLIDES"
HEIC = "HEIC"
HEIF = "HEIF"
HTML = "HTML"
JPEG = "JPEG"
JSON = "JSON"
M4A = "M4A"
M4V = "M4V"
MOV = "MOV"
MP3 = "MP3"
MP4 = "MP4"
MPEG = "MPEG"
MPEGPS = "MPEGPS"
MPG = "MPG"
MPA = "MPA"
MPGA = "MPGA"
OGG = "OGG"
OPUS = "OPUS"
PDF = "PDF"
PCM = "PCM"
PNG = "PNG"
PPT = "PPT"
PPTX = "PPTX"
RTF = "RTF"
THREE_GPP = "3GPP"
TXT = "TXT"
WAV = "WAV"
WEBM = "WEBM"
WEBP = "WEBP"
WMV = "WMV"
XLS = "XLS"
XLSX = "XLSX"
FILE_EXTENSIONS: MappingProxyType[FileType, List[str]] = MappingProxyType(
{
FileType.AAC: ["aac"],
FileType.CSV: ["csv"],
FileType.DOC: ["doc"],
FileType.DOCX: ["docx"],
FileType.FLAC: ["flac"],
FileType.FLV: ["flv"],
FileType.GIF: ["gif"],
FileType.GOOGLE_DOC: ["gdoc"],
FileType.GOOGLE_DRAWINGS: ["gdraw"],
FileType.GOOGLE_SHEETS: ["gsheet"],
FileType.GOOGLE_SLIDES: ["gslides"],
FileType.HEIC: ["heic"],
FileType.HEIF: ["heif"],
FileType.HTML: ["html", "htm"],
FileType.JPEG: ["jpeg", "jpg"],
FileType.JSON: ["json"],
FileType.M4A: ["m4a"],
FileType.M4V: ["m4v"],
FileType.MOV: ["mov"],
FileType.MP3: ["mp3"],
FileType.MP4: ["mp4"],
FileType.MPEG: ["mpeg"],
FileType.MPEGPS: ["mpegps"],
FileType.MPG: ["mpg"],
FileType.MPA: ["mpa"],
FileType.MPGA: ["mpga"],
FileType.OGG: ["ogg"],
FileType.OPUS: ["opus"],
FileType.PDF: ["pdf"],
FileType.PCM: ["pcm"],
FileType.PNG: ["png"],
FileType.PPT: ["ppt"],
FileType.PPTX: ["pptx"],
FileType.RTF: ["rtf"],
FileType.THREE_GPP: ["3gpp"],
FileType.TXT: ["txt"],
FileType.WAV: ["wav"],
FileType.WEBM: ["webm"],
FileType.WEBP: ["webp"],
FileType.WMV: ["wmv"],
FileType.XLS: ["xls"],
FileType.XLSX: ["xlsx"],
}
)
FILE_MIME_TYPES: MappingProxyType[FileType, str] = MappingProxyType(
{
FileType.AAC: "audio/aac",
FileType.CSV: "text/csv",
FileType.DOC: "application/msword",
FileType.DOCX: "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
FileType.FLAC: "audio/flac",
FileType.FLV: "video/x-flv",
FileType.GIF: "image/gif",
FileType.GOOGLE_DOC: "application/vnd.google-apps.document",
FileType.GOOGLE_DRAWINGS: "application/vnd.google-apps.drawing",
FileType.GOOGLE_SHEETS: "application/vnd.google-apps.spreadsheet",
FileType.GOOGLE_SLIDES: "application/vnd.google-apps.presentation",
FileType.HEIC: "image/heic",
FileType.HEIF: "image/heif",
FileType.HTML: "text/html",
FileType.JPEG: "image/jpeg",
FileType.JSON: "application/json",
FileType.M4A: "audio/x-m4a",
FileType.M4V: "video/x-m4v",
FileType.MOV: "video/quicktime",
FileType.MP3: "audio/mpeg",
FileType.MP4: "video/mp4",
FileType.MPEG: "video/mpeg",
FileType.MPEGPS: "video/mpegps",
FileType.MPG: "video/mpg",
FileType.MPA: "audio/m4a",
FileType.MPGA: "audio/mpga",
FileType.OGG: "audio/ogg",
FileType.OPUS: "audio/opus",
FileType.PDF: "application/pdf",
FileType.PCM: "audio/pcm",
FileType.PNG: "image/png",
FileType.PPT: "application/vnd.ms-powerpoint",
FileType.PPTX: "application/vnd.openxmlformats-officedocument.presentationml.presentation",
FileType.RTF: "application/rtf",
FileType.THREE_GPP: "video/3gpp",
FileType.TXT: "text/plain",
FileType.WAV: "audio/wav",
FileType.WEBM: "video/webm",
FileType.WEBP: "image/webp",
FileType.WMV: "video/wmv",
FileType.XLS: "application/vnd.ms-excel",
FileType.XLSX: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
}
)
"""
Util Functions
"""
def get_file_mime_type_from_extension(extension: str) -> str:
for file_type, extensions in FILE_EXTENSIONS.items():
if extension in extensions:
return FILE_MIME_TYPES[file_type]
raise ValueError(f"Unknown mime type for extension: {extension}")
def get_file_extension_from_mime_type(mime_type: str) -> str:
for file_type, mime in FILE_MIME_TYPES.items():
if mime == mime_type:
return FILE_EXTENSIONS[file_type][0]
raise ValueError(f"Unknown extension for mime type: {mime_type}")
def get_file_type_from_extension(extension: str) -> FileType:
for file_type, extensions in FILE_EXTENSIONS.items():
if extension in extensions:
return file_type
raise ValueError(f"Unknown file type for extension: {extension}")
def get_file_extension_for_file_type(file_type: FileType) -> str:
return FILE_EXTENSIONS[file_type][0]
def get_file_mime_type_for_file_type(file_type: FileType) -> str:
return FILE_MIME_TYPES[file_type]
"""
FileType Type Groupings (Videos, Images, etc)
"""
# Images
IMAGE_FILE_TYPES = {
FileType.PNG,
FileType.JPEG,
FileType.GIF,
FileType.WEBP,
FileType.HEIC,
FileType.HEIF,
}
def is_image_file_type(file_type):
return file_type in IMAGE_FILE_TYPES
# Videos
VIDEO_FILE_TYPES = {
FileType.MOV,
FileType.MP4,
FileType.MPEG,
FileType.M4V,
FileType.FLV,
FileType.MPEGPS,
FileType.MPG,
FileType.WEBM,
FileType.WMV,
FileType.THREE_GPP,
}
def is_video_file_type(file_type):
return file_type in VIDEO_FILE_TYPES
# Audio
AUDIO_FILE_TYPES = {
FileType.AAC,
FileType.FLAC,
FileType.MP3,
FileType.MPA,
FileType.MPGA,
FileType.OPUS,
FileType.PCM,
FileType.WAV,
}
def is_audio_file_type(file_type):
return file_type in AUDIO_FILE_TYPES
# Text
TEXT_FILE_TYPES = {FileType.CSV, FileType.HTML, FileType.RTF, FileType.TXT}
def is_text_file_type(file_type):
return file_type in TEXT_FILE_TYPES
"""
Other FileType Groupings
"""
# Accepted file types for GEMINI 1.5 through Vertex AI
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#gemini-send-multimodal-samples-images-nodejs
GEMINI_1_5_ACCEPTED_FILE_TYPES: Set[FileType] = {
# Image
FileType.PNG,
FileType.JPEG,
# Audio
FileType.AAC,
FileType.FLAC,
FileType.MP3,
FileType.MPA,
FileType.MPGA,
FileType.OPUS,
FileType.PCM,
FileType.WAV,
# Video
FileType.FLV,
FileType.MOV,
FileType.MPEG,
FileType.MPEGPS,
FileType.MPG,
FileType.MP4,
FileType.WEBM,
FileType.WMV,
FileType.THREE_GPP,
# PDF
FileType.PDF,
}
def is_gemini_1_5_accepted_file_type(file_type: FileType) -> bool:
return file_type in GEMINI_1_5_ACCEPTED_FILE_TYPES

View file

@ -1,4 +1,4 @@
from typing import TypedDict, Any, Union, Optional
from typing import TypedDict, Any, Union, Optional, Literal, List
import json
from typing_extensions import (
Self,
@ -11,10 +11,137 @@ from typing_extensions import (
)
class SystemContentBlock(TypedDict):
text: str
class ImageSourceBlock(TypedDict):
bytes: Optional[str] # base 64 encoded string
class ImageBlock(TypedDict):
format: Literal["png", "jpeg", "gif", "webp"]
source: ImageSourceBlock
class ToolResultContentBlock(TypedDict, total=False):
image: ImageBlock
json: dict
text: str
class ToolResultBlock(TypedDict, total=False):
content: Required[List[ToolResultContentBlock]]
toolUseId: Required[str]
status: Literal["success", "error"]
class ToolUseBlock(TypedDict):
input: dict
name: str
toolUseId: str
class ContentBlock(TypedDict, total=False):
text: str
image: ImageBlock
toolResult: ToolResultBlock
toolUse: ToolUseBlock
class MessageBlock(TypedDict):
content: List[ContentBlock]
role: Literal["user", "assistant"]
class ConverseMetricsBlock(TypedDict):
latencyMs: float # time in ms
class ConverseResponseOutputBlock(TypedDict):
message: Optional[MessageBlock]
class ConverseTokenUsageBlock(TypedDict):
inputTokens: int
outputTokens: int
totalTokens: int
class ConverseResponseBlock(TypedDict):
additionalModelResponseFields: dict
metrics: ConverseMetricsBlock
output: ConverseResponseOutputBlock
stopReason: (
str # end_turn | tool_use | max_tokens | stop_sequence | content_filtered
)
usage: ConverseTokenUsageBlock
class ToolInputSchemaBlock(TypedDict):
json: Optional[dict]
class ToolSpecBlock(TypedDict, total=False):
inputSchema: Required[ToolInputSchemaBlock]
name: Required[str]
description: str
class ToolBlock(TypedDict):
toolSpec: Optional[ToolSpecBlock]
class SpecificToolChoiceBlock(TypedDict):
name: str
class ToolChoiceValuesBlock(TypedDict, total=False):
any: dict
auto: dict
tool: SpecificToolChoiceBlock
class ToolConfigBlock(TypedDict, total=False):
tools: Required[List[ToolBlock]]
toolChoice: Union[str, ToolChoiceValuesBlock]
class InferenceConfig(TypedDict, total=False):
maxTokens: int
stopSequences: List[str]
temperature: float
topP: float
class ToolBlockDeltaEvent(TypedDict):
input: str
class ContentBlockDeltaEvent(TypedDict, total=False):
"""
Either 'text' or 'toolUse' will be specified for Converse API streaming response.
"""
text: str
toolUse: ToolBlockDeltaEvent
class RequestObject(TypedDict, total=False):
additionalModelRequestFields: dict
additionalModelResponseFieldPaths: List[str]
inferenceConfig: InferenceConfig
messages: Required[List[MessageBlock]]
system: List[SystemContentBlock]
toolConfig: ToolConfigBlock
class GenericStreamingChunk(TypedDict):
text: Required[str]
tool_str: Required[str]
is_finished: Required[bool]
finish_reason: Required[str]
usage: Optional[ConverseTokenUsageBlock]
class Document(TypedDict):

View file

@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False):
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
class ChatCompletionToolCallFunctionChunk(TypedDict):
name: str
arguments: str
class ChatCompletionToolCallChunk(TypedDict):
id: str
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
class ChatCompletionResponseMessage(TypedDict, total=False):
content: Optional[str]
tool_calls: List[ChatCompletionToolCallChunk]
role: Literal["assistant"]

View file

@ -3,7 +3,7 @@ from pydantic import BaseModel, Field
from typing import Optional
class ServiceTypes(enum.Enum):
class ServiceTypes(str, enum.Enum):
"""
Enum for litellm + litellm-adjacent services (redis/postgres/etc.)
"""

View file

@ -239,6 +239,8 @@ def map_finish_reason(
return "length"
elif finish_reason == "tool_use": # anthropic
return "tool_calls"
elif finish_reason == "content_filtered":
return "content_filter"
return finish_reason
@ -5655,19 +5657,29 @@ def get_optional_params(
optional_params["stream"] = stream
elif "anthropic" in model:
_check_valid_arg(supported_params=supported_params)
# anthropic params on bedrock
# \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
else: # bedrock httpx route
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif "amazon" in model: # amazon titan llms
_check_valid_arg(supported_params=supported_params)
@ -6445,20 +6457,7 @@ def get_supported_openai_params(
- None if unmapped
"""
if custom_llm_provider == "bedrock":
if model.startswith("anthropic.claude-3"):
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
elif model.startswith("anthropic"):
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
elif model.startswith("ai21"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("amazon"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
elif model.startswith("meta"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("cohere"):
return ["stream", "temperature", "max_tokens"]
elif model.startswith("mistral"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_supported_openai_params()
elif custom_llm_provider == "ollama_chat":
@ -8558,7 +8557,11 @@ def exception_type(
extra_information = f"\nModel: {model}"
if _api_base:
extra_information += f"\nAPI Base: `{_api_base}`"
if messages and len(messages) > 0:
if (
messages
and len(messages) > 0
and litellm.redact_messages_in_exceptions is False
):
extra_information += f"\nMessages: `{messages}`"
if _model_group is not None:
@ -9124,7 +9127,7 @@ def exception_type(
if "Unable to locate credentials" in error_str:
exception_mapping_worked = True
raise BadRequestError(
message=f"SagemakerException - {error_str}",
message=f"litellm.BadRequestError: SagemakerException - {error_str}",
model=model,
llm_provider="sagemaker",
response=original_exception.response,
@ -9158,10 +9161,16 @@ def exception_type(
):
exception_mapping_worked = True
raise BadRequestError(
message=f"VertexAIException BadRequestError - {error_str}",
message=f"litellm.BadRequestError: VertexAIException - {error_str}",
model=model,
llm_provider="vertex_ai",
response=original_exception.response,
response=httpx.Response(
status_code=429,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
),
litellm_debug_info=extra_information,
)
elif (
@ -9169,12 +9178,19 @@ def exception_type(
or "Content has no parts." in error_str
):
exception_mapping_worked = True
raise APIError(
message=f"VertexAIException APIError - {error_str}",
raise litellm.InternalServerError(
message=f"litellm.InternalServerError: VertexAIException - {error_str}",
status_code=500,
model=model,
llm_provider="vertex_ai",
request=original_exception.request,
request=(
original_exception.request
if hasattr(original_exception, "request")
else httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
)
),
litellm_debug_info=extra_information,
)
elif "403" in error_str:
@ -9183,7 +9199,13 @@ def exception_type(
message=f"VertexAIException BadRequestError - {error_str}",
model=model,
llm_provider="vertex_ai",
response=original_exception.response,
response=httpx.Response(
status_code=429,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
),
litellm_debug_info=extra_information,
)
elif "The response was blocked." in error_str:
@ -9230,12 +9252,18 @@ def exception_type(
model=model,
llm_provider="vertex_ai",
litellm_debug_info=extra_information,
response=original_exception.response,
response=httpx.Response(
status_code=429,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
),
)
if original_exception.status_code == 500:
exception_mapping_worked = True
raise APIError(
message=f"VertexAIException APIError - {error_str}",
raise litellm.InternalServerError(
message=f"VertexAIException InternalServerError - {error_str}",
status_code=500,
model=model,
llm_provider="vertex_ai",
@ -11423,12 +11451,27 @@ class CustomStreamWrapper:
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj = self.handle_bedrock_stream(chunk)
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
@ -11695,7 +11738,7 @@ class CustomStreamWrapper:
and hasattr(model_response, "usage")
and hasattr(model_response.usage, "prompt_tokens")
):
if self.sent_first_chunk == False:
if self.sent_first_chunk is False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
@ -11863,6 +11906,8 @@ class CustomStreamWrapper:
def __next__(self):
try:
if self.completion_stream is None:
self.fetch_sync_stream()
while True:
if (
isinstance(self.completion_stream, str)
@ -11937,6 +11982,14 @@ class CustomStreamWrapper:
custom_llm_provider=self.custom_llm_provider,
)
def fetch_sync_stream(self):
if self.completion_stream is None and self.make_call is not None:
# Call make_call to get the completion stream
self.completion_stream = self.make_call(client=litellm.module_level_client)
self._stream_iter = self.completion_stream.__iter__()
return self.completion_stream
async def fetch_stream(self):
if self.completion_stream is None and self.make_call is not None:
# Call make_call to get the completion stream

View file

@ -5,6 +5,10 @@ description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
readme = "README.md"
packages = [
{ include = "litellm" },
{ include = "litellm/py.typed"},
]
[tool.poetry.urls]
homepage = "https://litellm.ai"

View file

@ -1 +1,3 @@
ignore = ["F403", "F401"]
ignore = ["F405"]
extend-select = ["E501"]
line-length = 120