Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 03:34:10 +00:00)

Merge branch 'main' into litellm_security_fix

Commit 92841dfe1b: 31 changed files with 2394 additions and 5332 deletions
.gitignore (vendored) - 1 change

@@ -59,3 +59,4 @@ myenv/*
 litellm/proxy/_experimental/out/404/index.html
 litellm/proxy/_experimental/out/model_hub/index.html
 litellm/proxy/_experimental/out/onboarding/index.html
+litellm/tests/log.txt
@@ -62,6 +62,23 @@ curl -X GET 'http://localhost:4000/health/services?service=slack' \
     -H 'Authorization: Bearer sk-1234'
 ```
 
+## Advanced - Redacting Messages from Alerts
+
+By default, alerts show the `messages/input` passed to the LLM. If you want to redact this from Slack alerting, set the following in your config:
+
+```shell
+general_settings:
+  alerting: ["slack"]
+  alert_types: ["spend_reports"]
+
+litellm_settings:
+  redact_messages_in_exceptions: True
+```
+
+
 ## Advanced - Opting into specific alert types
 
 Set `alert_types` if you want to opt into only specific alert types
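For SDK (non-proxy) users, the same behavior is driven by the module-level flag this commit adds to `litellm/__init__.py`; a minimal sketch (the model name and message content are illustrative):

```python
import litellm

# Added in this commit: when True, Slack alerts replace the request `messages`
# with a redaction notice instead of the raw user input.
litellm.redact_messages_in_exceptions = True

# Alerts fired for this call (e.g. slow response, failed request) will no
# longer include the message contents.
response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "sensitive text"}],
)
```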
@@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
 ### INIT VARIABLES ###
 import threading, requests, os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.caching import Cache
 from litellm._logging import (
     set_verbose,
@@ -60,6 +60,7 @@ _async_failure_callback: List[Callable] = (
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
 turn_off_message_logging: Optional[bool] = False
+redact_messages_in_exceptions: Optional[bool] = False
 store_audit_logs = False  # Enterprise feature, allow users to see audit logs
 ## end of callbacks #############
 
@@ -233,6 +234,7 @@ max_end_user_budget: Optional[float] = None
 #### RELIABILITY ####
 request_timeout: float = 6000
 module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
+module_level_client = HTTPHandler(timeout=request_timeout)
 num_retries: Optional[int] = None  # per model endpoint
 default_fallbacks: Optional[List] = None
 fallbacks: Optional[List] = None
@@ -766,7 +768,7 @@ from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
-from .llms.bedrock_httpx import AmazonCohereChatConfig
+from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
 from .llms.bedrock import (
     AmazonTitanConfig,
     AmazonAI21Config,
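The new `module_level_client` mirrors the existing `module_level_aclient` so synchronous code paths can reuse one shared `HTTPHandler` instead of building a client per call; a minimal sketch (the URL and payload are illustrative):

```python
import litellm

# Shared handler created at import time with the module-level request_timeout
resp = litellm.module_level_client.post(
    "https://example.com/endpoint",  # illustrative URL
    data='{"ping": "pong"}',
    headers={"Content-Type": "application/json"},
)
print(resp.status_code)
```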
@@ -1,10 +1,18 @@
-import litellm, traceback
+from datetime import datetime
+import litellm
 from litellm.proxy._types import UserAPIKeyAuth
 from .types.services import ServiceTypes, ServiceLoggerPayload
 from .integrations.prometheus_services import PrometheusServicesLogger
 from .integrations.custom_logger import CustomLogger
 from datetime import timedelta
-from typing import Union
+from typing import Union, Optional, TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
 
 
 class ServiceLogging(CustomLogger):
@@ -40,7 +48,13 @@ class ServiceLogging(CustomLogger):
         self.mock_testing_sync_failure_hook += 1
 
     async def async_service_success_hook(
-        self, service: ServiceTypes, duration: float, call_type: str
+        self,
+        service: ServiceTypes,
+        call_type: str,
+        duration: float,
+        parent_otel_span: Optional[Span] = None,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
     ):
         """
         - For counting if the redis, postgres call is successful
@@ -61,6 +75,16 @@ class ServiceLogging(CustomLogger):
                 payload=payload
             )
 
+        from litellm.proxy.proxy_server import open_telemetry_logger
+
+        if parent_otel_span is not None and open_telemetry_logger is not None:
+            await open_telemetry_logger.async_service_success_hook(
+                payload=payload,
+                parent_otel_span=parent_otel_span,
+                start_time=start_time,
+                end_time=end_time,
+            )
+
     async def async_service_failure_hook(
         self,
         service: ServiceTypes,
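A sketch of how proxy-side instrumentation might call the extended hook so a Redis or Postgres call shows up as a child span of the request; the module path, `ServiceTypes` member, and call_type label are assumptions based on the imports above, not confirmed by this diff:

```python
import asyncio
from datetime import datetime

from litellm._service_logger import ServiceLogging  # module path assumed
from litellm.types.services import ServiceTypes

async def record_cache_read(parent_otel_span=None):
    service_logger = ServiceLogging()
    start_time = datetime.now()
    # ... perform the Redis/Postgres call here ...
    end_time = datetime.now()
    await service_logger.async_service_success_hook(
        service=ServiceTypes.REDIS,          # assumes a REDIS member exists
        call_type="async_get_cache",         # illustrative label
        duration=(end_time - start_time).total_seconds(),
        parent_otel_span=parent_otel_span,   # new in this commit; may be None
        start_time=start_time,
        end_time=end_time,
    )

asyncio.run(record_cache_read())
```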
@@ -1,9 +1,21 @@
 import os
-from typing import Optional
 from dataclasses import dataclass
+from datetime import datetime
+
 from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_logger
+from litellm.types.services import ServiceLoggerPayload
+from typing import Union, Optional, TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+    from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
+
+    Span = _Span
+    UserAPIKeyAuth = _UserAPIKeyAuth
+else:
+    Span = Any
+    UserAPIKeyAuth = Any
+
 
 LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm")
@@ -77,6 +89,56 @@ class OpenTelemetry(CustomLogger):
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         self._handle_failure(kwargs, response_obj, start_time, end_time)
 
+    async def async_service_success_hook(
+        self,
+        payload: ServiceLoggerPayload,
+        parent_otel_span: Optional[Span] = None,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
+    ):
+        from opentelemetry import trace
+        from datetime import datetime
+        from opentelemetry.trace import Status, StatusCode
+
+        if parent_otel_span is not None:
+            _span_name = payload.service
+            service_logging_span = self.tracer.start_span(
+                name=_span_name,
+                context=trace.set_span_in_context(parent_otel_span),
+                start_time=self._to_ns(start_time),
+            )
+            service_logging_span.set_attribute(key="call_type", value=payload.call_type)
+            service_logging_span.set_attribute(
+                key="service", value=payload.service.value
+            )
+            service_logging_span.set_status(Status(StatusCode.OK))
+            service_logging_span.end(end_time=self._to_ns(end_time))
+
+    async def async_post_call_failure_hook(
+        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
+    ):
+        from opentelemetry.trace import Status, StatusCode
+        from opentelemetry import trace
+
+        parent_otel_span = user_api_key_dict.parent_otel_span
+        if parent_otel_span is not None:
+            parent_otel_span.set_status(Status(StatusCode.ERROR))
+            _span_name = "Failed Proxy Server Request"
+
+            # Exception Logging Child Span
+            exception_logging_span = self.tracer.start_span(
+                name=_span_name,
+                context=trace.set_span_in_context(parent_otel_span),
+            )
+            exception_logging_span.set_attribute(
+                key="exception", value=str(original_exception)
+            )
+            exception_logging_span.set_status(Status(StatusCode.ERROR))
+            exception_logging_span.end(end_time=self._to_ns(datetime.now()))
+
+            # End Parent OTEL span
+            parent_otel_span.end(end_time=self._to_ns(datetime.now()))
+
     def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
         from opentelemetry.trace import Status, StatusCode
 
@@ -85,15 +147,18 @@ class OpenTelemetry(CustomLogger):
             kwargs,
             self.config,
         )
+        _parent_context, parent_otel_span = self._get_span_context(kwargs)
+
         span = self.tracer.start_span(
             name=self._get_span_name(kwargs),
             start_time=self._to_ns(start_time),
-            context=self._get_span_context(kwargs),
+            context=_parent_context,
         )
         span.set_status(Status(StatusCode.OK))
         self.set_attributes(span, kwargs, response_obj)
         span.end(end_time=self._to_ns(end_time))
+        if parent_otel_span is not None:
+            parent_otel_span.end(end_time=self._to_ns(datetime.now()))
 
     def _handle_failure(self, kwargs, response_obj, start_time, end_time):
         from opentelemetry.trace import Status, StatusCode
@@ -122,17 +187,28 @@ class OpenTelemetry(CustomLogger):
         from opentelemetry.trace.propagation.tracecontext import (
             TraceContextTextMapPropagator,
         )
+        from opentelemetry import trace
+
         litellm_params = kwargs.get("litellm_params", {}) or {}
         proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
         headers = proxy_server_request.get("headers", {}) or {}
         traceparent = headers.get("traceparent", None)
+        _metadata = litellm_params.get("metadata", {})
+        parent_otel_span = _metadata.get("litellm_parent_otel_span", None)
+
+        """
+        Two ways to use parents in opentelemetry
+        - using the traceparent header
+        - using the parent_otel_span in the [metadata][parent_otel_span]
+        """
+        if parent_otel_span is not None:
+            return trace.set_span_in_context(parent_otel_span), parent_otel_span
+
         if traceparent is None:
-            return None
+            return None, None
         else:
             carrier = {"traceparent": traceparent}
-            return TraceContextTextMapPropagator().extract(carrier=carrier)
+            return TraceContextTextMapPropagator().extract(carrier=carrier), None
 
     def _get_span_processor(self):
         from opentelemetry.sdk.trace.export import (
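For context, a sketch of the two kwargs shapes the helper above accepts for establishing a parent; the header name and metadata key follow the code, everything else is illustrative:

```python
# Option 1: the caller forwarded a W3C `traceparent` header through the proxy request
kwargs_via_header = {
    "litellm_params": {
        "proxy_server_request": {
            "headers": {
                "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
            }
        }
    }
}

# Option 2: an already-open OTEL span is threaded through request metadata
from opentelemetry import trace

tracer = trace.get_tracer("litellm")
with tracer.start_as_current_span("proxy_request") as parent_span:
    kwargs_via_metadata = {
        "litellm_params": {"metadata": {"litellm_parent_otel_span": parent_span}}
    }
    # _get_span_context(kwargs) returns (context, parent_span) for option 2,
    # and (extracted_context, None) for option 1.
```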
@@ -326,8 +326,8 @@ class SlackAlerting(CustomLogger):
                     end_time=end_time,
                 )
             )
-            if litellm.turn_off_message_logging:
-                messages = "Message not logged. `litellm.turn_off_message_logging=True`."
+            if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
+                messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
             request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
             slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
             if time_difference_float > self.alerting_threshold:
@@ -567,9 +567,12 @@ class SlackAlerting(CustomLogger):
         except:
             messages = ""
 
-        if litellm.turn_off_message_logging:
+        if (
+            litellm.turn_off_message_logging
+            or litellm.redact_messages_in_exceptions
+        ):
             messages = (
-                "Message not logged. `litellm.turn_off_message_logging=True`."
+                "Message not logged. litellm.redact_messages_in_exceptions=True"
             )
             request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
         else:
@@ -38,6 +38,8 @@ from .prompt_templates.factory import (
     extract_between_tags,
     parse_xml_params,
     contains_tag,
+    _bedrock_converse_messages_pt,
+    _bedrock_tools_pt,
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
@@ -45,6 +47,11 @@ import httpx  # type: ignore
 from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
 from litellm.types.llms.bedrock import *
 import urllib.parse
+from litellm.types.llms.openai import (
+    ChatCompletionResponseMessage,
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+)
 
 
 class AmazonCohereChatConfig:
@@ -118,6 +125,8 @@ class AmazonCohereChatConfig:
             "presence_penalty",
             "seed",
             "stop",
+            "tools",
+            "tool_choice",
         ]
 
     def map_openai_params(
@@ -169,7 +178,38 @@ async def make_call(
     logging_obj.post_call(
         input=messages,
         api_key="",
-        original_response=completion_stream,  # Pass the completion stream for logging
+        original_response="first stream response received",
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
+def make_sync_call(
+    client: Optional[HTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = HTTPHandler()  # Create a new client if none provided
+
+    response = client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise BedrockError(status_code=response.status_code, message=response.read())
+
+    decoder = AWSEventStreamDecoder(model=model)
+    completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response="first stream response received",
         additional_args={"complete_input_dict": data},
     )
 
@@ -1000,12 +1040,12 @@ class BedrockLLM(BaseLLM):
             if isinstance(timeout, float) or isinstance(timeout, int):
                 timeout = httpx.Timeout(timeout)
                 _params["timeout"] = timeout
-            self.client = AsyncHTTPHandler(**_params)  # type: ignore
+            client = AsyncHTTPHandler(**_params)  # type: ignore
         else:
-            self.client = client  # type: ignore
+            client = client  # type: ignore
 
         try:
-            response = await self.client.post(api_base, headers=headers, data=data)  # type: ignore
+            response = await client.post(api_base, headers=headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
@@ -1069,6 +1109,738 @@ class BedrockLLM(BaseLLM):
         return super().embedding(*args, **kwargs)
+
+
+class AmazonConverseConfig:
+    """
+    Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+    #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
+    """
+
+    maxTokens: Optional[int]
+    stopSequences: Optional[List[str]]
+    temperature: Optional[int]
+    topP: Optional[int]
+
+    def __init__(
+        self,
+        maxTokens: Optional[int] = None,
+        stopSequences: Optional[List[str]] = None,
+        temperature: Optional[int] = None,
+        topP: Optional[int] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self, model: str) -> List[str]:
+        supported_params = [
+            "max_tokens",
+            "stream",
+            "stream_options",
+            "stop",
+            "temperature",
+            "top_p",
+            "extra_headers",
+        ]
+
+        if (
+            model.startswith("anthropic")
+            or model.startswith("mistral")
+            or model.startswith("cohere")
+        ):
+            supported_params.append("tools")
+
+        if model.startswith("anthropic") or model.startswith("mistral"):
+            # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+            supported_params.append("tool_choice")
+
+        return supported_params
+
+    def map_tool_choice_values(
+        self, model: str, tool_choice: Union[str, dict], drop_params: bool
+    ) -> Optional[ToolChoiceValuesBlock]:
+        if tool_choice == "none":
+            if litellm.drop_params is True or drop_params is True:
+                return None
+            else:
+                raise litellm.utils.UnsupportedParamsError(
+                    message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
+                        tool_choice
+                    ),
+                    status_code=400,
+                )
+        elif tool_choice == "required":
+            return ToolChoiceValuesBlock(any={})
+        elif tool_choice == "auto":
+            return ToolChoiceValuesBlock(auto={})
+        elif isinstance(tool_choice, dict):
+            # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+            specific_tool = SpecificToolChoiceBlock(
+                name=tool_choice.get("function", {}).get("name", "")
+            )
+            return ToolChoiceValuesBlock(tool=specific_tool)
+        else:
+            raise litellm.utils.UnsupportedParamsError(
+                message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
+                    tool_choice
+                ),
+                status_code=400,
+            )
+
+    def get_supported_image_types(self) -> List[str]:
+        return ["png", "jpeg", "gif", "webp"]
+
+    def map_openai_params(
+        self,
+        model: str,
+        non_default_params: dict,
+        optional_params: dict,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["maxTokens"] = value
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                if isinstance(value, str):
+                    value = [value]
+                optional_params["stop_sequences"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["topP"] = value
+            if param == "tools":
+                optional_params["tools"] = value
+            if param == "tool_choice":
+                _tool_choice_value = self.map_tool_choice_values(
+                    model=model, tool_choice=value, drop_params=drop_params  # type: ignore
+                )
+                if _tool_choice_value is not None:
+                    optional_params["tool_choice"] = _tool_choice_value
+        return optional_params
+
+
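To make the mapping concrete, a sketch of how the new config translates OpenAI-style parameters into Converse fields (the model id and values are illustrative):

```python
from litellm.llms.bedrock_httpx import AmazonConverseConfig

config = AmazonConverseConfig()
model = "anthropic.claude-3-sonnet-20240229-v1:0"  # illustrative Bedrock model id

# anthropic/mistral models also report "tools" and "tool_choice" as supported
print(config.get_supported_openai_params(model=model))

optional_params = config.map_openai_params(
    model=model,
    non_default_params={"max_tokens": 256, "top_p": 0.9, "tool_choice": "auto"},
    optional_params={},
    drop_params=False,
)
# roughly: {'maxTokens': 256, 'topP': 0.9, 'tool_choice': {'auto': {}}}
print(optional_params)
```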
|
class BedrockConverseLLM(BaseLLM):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: Union[requests.Response, httpx.Response],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
stream: bool,
|
||||||
|
logging_obj: Logging,
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: Union[dict, str],
|
||||||
|
messages: List,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise BedrockError(
|
||||||
|
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
|
||||||
|
response.text, str(e)
|
||||||
|
),
|
||||||
|
status_code=422,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Bedrock Response Object has optional message block
|
||||||
|
|
||||||
|
completion_response["output"].get("message", None)
|
||||||
|
|
||||||
|
A message block looks like this (Example 1):
|
||||||
|
"output": {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
(Example 2):
|
||||||
|
"output": {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"toolUse": {
|
||||||
|
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
|
||||||
|
"name": "top_song",
|
||||||
|
"input": {
|
||||||
|
"sign": "WZPZ"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
"""
|
||||||
|
message: Optional[MessageBlock] = completion_response["output"]["message"]
|
||||||
|
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
|
||||||
|
content_str = ""
|
||||||
|
tools: List[ChatCompletionToolCallChunk] = []
|
||||||
|
if message is not None:
|
||||||
|
for content in message["content"]:
|
||||||
|
"""
|
||||||
|
- Content is either a tool response or text
|
||||||
|
"""
|
||||||
|
if "text" in content:
|
||||||
|
content_str += content["text"]
|
||||||
|
if "toolUse" in content:
|
||||||
|
_function_chunk = ChatCompletionToolCallFunctionChunk(
|
||||||
|
name=content["toolUse"]["name"],
|
||||||
|
arguments=json.dumps(content["toolUse"]["input"]),
|
||||||
|
)
|
||||||
|
_tool_response_chunk = ChatCompletionToolCallChunk(
|
||||||
|
id=content["toolUse"]["toolUseId"],
|
||||||
|
type="function",
|
||||||
|
function=_function_chunk,
|
||||||
|
)
|
||||||
|
tools.append(_tool_response_chunk)
|
||||||
|
chat_completion_message["content"] = content_str
|
||||||
|
chat_completion_message["tool_calls"] = tools
|
||||||
|
|
||||||
|
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||||
|
input_tokens = completion_response["usage"]["inputTokens"]
|
||||||
|
output_tokens = completion_response["usage"]["outputTokens"]
|
||||||
|
total_tokens = completion_response["usage"]["totalTokens"]
|
||||||
|
|
||||||
|
model_response.choices = [
|
||||||
|
litellm.Choices(
|
||||||
|
finish_reason=map_finish_reason(completion_response["stopReason"]),
|
||||||
|
index=0,
|
||||||
|
message=litellm.Message(**chat_completion_message),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
model_response["created"] = int(time.time())
|
||||||
|
model_response["model"] = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=input_tokens,
|
||||||
|
completion_tokens=output_tokens,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
)
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def encode_model_id(self, model_id: str) -> str:
|
||||||
|
"""
|
||||||
|
Double encode the model ID to ensure it matches the expected double-encoded format.
|
||||||
|
Args:
|
||||||
|
model_id (str): The model ID to encode.
|
||||||
|
Returns:
|
||||||
|
str: The double-encoded model ID.
|
||||||
|
"""
|
||||||
|
return urllib.parse.quote(model_id, safe="")
|
||||||
|
|
||||||
|
def get_credentials(
|
||||||
|
self,
|
||||||
|
aws_access_key_id: Optional[str] = None,
|
||||||
|
aws_secret_access_key: Optional[str] = None,
|
||||||
|
aws_region_name: Optional[str] = None,
|
||||||
|
aws_session_name: Optional[str] = None,
|
||||||
|
aws_profile_name: Optional[str] = None,
|
||||||
|
aws_role_name: Optional[str] = None,
|
||||||
|
aws_web_identity_token: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Return a boto3.Credentials object
|
||||||
|
"""
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
## CHECK IS 'os.environ/' passed in
|
||||||
|
params_to_check: List[Optional[str]] = [
|
||||||
|
aws_access_key_id,
|
||||||
|
aws_secret_access_key,
|
||||||
|
aws_region_name,
|
||||||
|
aws_session_name,
|
||||||
|
aws_profile_name,
|
||||||
|
aws_role_name,
|
||||||
|
aws_web_identity_token,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Iterate over parameters and update if needed
|
||||||
|
for i, param in enumerate(params_to_check):
|
||||||
|
if param and param.startswith("os.environ/"):
|
||||||
|
_v = get_secret(param)
|
||||||
|
if _v is not None and isinstance(_v, str):
|
||||||
|
params_to_check[i] = _v
|
||||||
|
# Assign updated values back to parameters
|
||||||
|
(
|
||||||
|
aws_access_key_id,
|
||||||
|
aws_secret_access_key,
|
||||||
|
aws_region_name,
|
||||||
|
aws_session_name,
|
||||||
|
aws_profile_name,
|
||||||
|
aws_role_name,
|
||||||
|
aws_web_identity_token,
|
||||||
|
) = params_to_check
|
||||||
|
|
||||||
|
### CHECK STS ###
|
||||||
|
if (
|
||||||
|
aws_web_identity_token is not None
|
||||||
|
and aws_role_name is not None
|
||||||
|
and aws_session_name is not None
|
||||||
|
):
|
||||||
|
oidc_token = get_secret(aws_web_identity_token)
|
||||||
|
|
||||||
|
if oidc_token is None:
|
||||||
|
raise BedrockError(
|
||||||
|
message="OIDC token could not be retrieved from secret manager.",
|
||||||
|
status_code=401,
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_client = boto3.client("sts")
|
||||||
|
|
||||||
|
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||||
|
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||||
|
sts_response = sts_client.assume_role_with_web_identity(
|
||||||
|
RoleArn=aws_role_name,
|
||||||
|
RoleSessionName=aws_session_name,
|
||||||
|
WebIdentityToken=oidc_token,
|
||||||
|
DurationSeconds=3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
session = boto3.Session(
|
||||||
|
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||||
|
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||||
|
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||||
|
region_name=aws_region_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return session.get_credentials()
|
||||||
|
elif aws_role_name is not None and aws_session_name is not None:
|
||||||
|
sts_client = boto3.client(
|
||||||
|
"sts",
|
||||||
|
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||||
|
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_response = sts_client.assume_role(
|
||||||
|
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the credentials from the response and convert to Session Credentials
|
||||||
|
sts_credentials = sts_response["Credentials"]
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
|
||||||
|
credentials = Credentials(
|
||||||
|
access_key=sts_credentials["AccessKeyId"],
|
||||||
|
secret_key=sts_credentials["SecretAccessKey"],
|
||||||
|
token=sts_credentials["SessionToken"],
|
||||||
|
)
|
||||||
|
return credentials
|
||||||
|
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||||
|
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||||
|
client = boto3.Session(profile_name=aws_profile_name)
|
||||||
|
|
||||||
|
return client.get_credentials()
|
||||||
|
else:
|
||||||
|
session = boto3.Session(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
region_name=aws_region_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return session.get_credentials()
|
||||||
|
|
||||||
|
async def async_streaming(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> CustomStreamWrapper:
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=None,
|
||||||
|
make_call=partial(
|
||||||
|
make_call,
|
||||||
|
client=client,
|
||||||
|
api_base=api_base,
|
||||||
|
headers=headers,
|
||||||
|
data=data,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
),
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="bedrock",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
|
async def async_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
if client is None:
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
client = AsyncHTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
client = client # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(api_base, headers=headers, data=data) # type: ignore
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||||
|
except httpx.TimeoutException as e:
|
||||||
|
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream if isinstance(stream, bool) else False,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
custom_prompt_dict: dict,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
|
acompletion: bool,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from botocore.auth import SigV4Auth
|
||||||
|
from botocore.awsrequest import AWSRequest
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
|
||||||
|
## SETUP ##
|
||||||
|
stream = optional_params.pop("stream", None)
|
||||||
|
modelId = optional_params.pop("model_id", None)
|
||||||
|
if modelId is not None:
|
||||||
|
modelId = self.encode_model_id(model_id=modelId)
|
||||||
|
else:
|
||||||
|
modelId = model
|
||||||
|
|
||||||
|
provider = model.split(".")[0]
|
||||||
|
|
||||||
|
## CREDENTIALS ##
|
||||||
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
|
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
|
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
|
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||||
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
|
"aws_bedrock_runtime_endpoint", None
|
||||||
|
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||||
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||||
|
|
||||||
|
### SET REGION NAME ###
|
||||||
|
if aws_region_name is None:
|
||||||
|
# check env #
|
||||||
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
|
|
||||||
|
if litellm_aws_region_name is not None and isinstance(
|
||||||
|
litellm_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = litellm_aws_region_name
|
||||||
|
|
||||||
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
if standard_aws_region_name is not None and isinstance(
|
||||||
|
standard_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = standard_aws_region_name
|
||||||
|
|
||||||
|
if aws_region_name is None:
|
||||||
|
aws_region_name = "us-west-2"
|
||||||
|
|
||||||
|
credentials: Credentials = self.get_credentials(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
aws_session_name=aws_session_name,
|
||||||
|
aws_profile_name=aws_profile_name,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
### SET RUNTIME ENDPOINT ###
|
||||||
|
endpoint_url = ""
|
||||||
|
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||||
|
if aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||||
|
aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
endpoint_url = aws_bedrock_runtime_endpoint
|
||||||
|
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||||
|
env_aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||||
|
else:
|
||||||
|
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||||
|
|
||||||
|
if (stream is not None and stream is True) and provider != "ai21":
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
|
||||||
|
else:
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
|
||||||
|
|
||||||
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||||
|
|
||||||
|
# Separate system prompt from rest of message
|
||||||
|
system_prompt_indices = []
|
||||||
|
system_content_blocks: List[SystemContentBlock] = []
|
||||||
|
for idx, message in enumerate(messages):
|
||||||
|
if message["role"] == "system":
|
||||||
|
_system_content_block = SystemContentBlock(text=message["content"])
|
||||||
|
system_content_blocks.append(_system_content_block)
|
||||||
|
system_prompt_indices.append(idx)
|
||||||
|
if len(system_prompt_indices) > 0:
|
||||||
|
for idx in reversed(system_prompt_indices):
|
||||||
|
messages.pop(idx)
|
||||||
|
|
||||||
|
inference_params = copy.deepcopy(optional_params)
|
||||||
|
additional_request_keys = []
|
||||||
|
additional_request_params = {}
|
||||||
|
supported_converse_params = AmazonConverseConfig.__annotations__.keys()
|
||||||
|
supported_tool_call_params = ["tools", "tool_choice"]
|
||||||
|
## TRANSFORMATION ##
|
||||||
|
# send all model-specific params in 'additional_request_params'
|
||||||
|
for k, v in inference_params.items():
|
||||||
|
if (
|
||||||
|
k not in supported_converse_params
|
||||||
|
and k not in supported_tool_call_params
|
||||||
|
):
|
||||||
|
additional_request_params[k] = v
|
||||||
|
additional_request_keys.append(k)
|
||||||
|
for key in additional_request_keys:
|
||||||
|
inference_params.pop(key, None)
|
||||||
|
|
||||||
|
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||||
|
messages=messages
|
||||||
|
)
|
||||||
|
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||||
|
inference_params.pop("tools", [])
|
||||||
|
)
|
||||||
|
bedrock_tool_config: Optional[ToolConfigBlock] = None
|
||||||
|
if len(bedrock_tools) > 0:
|
||||||
|
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
|
||||||
|
"tool_choice", None
|
||||||
|
)
|
||||||
|
bedrock_tool_config = ToolConfigBlock(
|
||||||
|
tools=bedrock_tools,
|
||||||
|
)
|
||||||
|
if tool_choice_values is not None:
|
||||||
|
bedrock_tool_config["toolChoice"] = tool_choice_values
|
||||||
|
|
||||||
|
_data: RequestObject = {
|
||||||
|
"messages": bedrock_messages,
|
||||||
|
"additionalModelRequestFields": additional_request_params,
|
||||||
|
"system": system_content_blocks,
|
||||||
|
"inferenceConfig": InferenceConfig(**inference_params),
|
||||||
|
}
|
||||||
|
if bedrock_tool_config is not None:
|
||||||
|
_data["toolConfig"] = bedrock_tool_config
|
||||||
|
data = json.dumps(_data)
|
||||||
|
## COMPLETION CALL
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers = {"Content-Type": "application/json", **extra_headers}
|
||||||
|
request = AWSRequest(
|
||||||
|
method="POST", url=endpoint_url, data=data, headers=headers
|
||||||
|
)
|
||||||
|
sigv4.add_auth(request)
|
||||||
|
prepped = request.prepare()
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"api_base": prepped.url,
|
||||||
|
"headers": prepped.headers,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||||
|
if acompletion:
|
||||||
|
if isinstance(client, HTTPHandler):
|
||||||
|
client = None
|
||||||
|
if stream is True and provider != "ai21":
|
||||||
|
return self.async_streaming(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=prepped.url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=True,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.async_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=prepped.url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=stream, # type: ignore
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
if (stream is not None and stream is True) and provider != "ai21":
|
||||||
|
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=None,
|
||||||
|
make_call=partial(
|
||||||
|
make_sync_call,
|
||||||
|
client=None,
|
||||||
|
api_base=prepped.url,
|
||||||
|
headers=prepped.headers, # type: ignore
|
||||||
|
data=data,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
),
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="bedrock",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
return streaming_response
|
||||||
|
### COMPLETION
|
||||||
|
|
||||||
|
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
client = HTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
client = client
|
||||||
|
try:
|
||||||
|
response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise BedrockError(status_code=error_code, message=response.text)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_response_stream_shape():
|
def get_response_stream_shape():
|
||||||
from botocore.model import ServiceModel
|
from botocore.model import ServiceModel
|
||||||
from botocore.loaders import Loader
|
from botocore.loaders import Loader
|
||||||
|
@@ -1086,6 +1858,31 @@ class AWSEventStreamDecoder:
         self.model = model
         self.parser = EventStreamJSONParser()
 
+    def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+        text = ""
+        tool_str = ""
+        is_finished = False
+        finish_reason = ""
+        usage: Optional[ConverseTokenUsageBlock] = None
+        if "delta" in chunk_data:
+            delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
+            if "text" in delta_obj:
+                text = delta_obj["text"]
+            elif "toolUse" in delta_obj:
+                tool_str = delta_obj["toolUse"]["input"]
+        elif "stopReason" in chunk_data:
+            finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
+        elif "usage" in chunk_data:
+            usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore
+        response = GenericStreamingChunk(
+            text=text,
+            tool_str=tool_str,
+            is_finished=is_finished,
+            finish_reason=finish_reason,
+            usage=usage,
+        )
+        return response
+
     def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
         is_finished = False
@@ -1098,19 +1895,8 @@ class AWSEventStreamDecoder:
                 is_finished = True
                 finish_reason = "stop"
             ######## bedrock.anthropic mappings ###############
-            elif "completion" in chunk_data:  # not claude-3
-                text = chunk_data["completion"]  # bedrock.anthropic
-                stop_reason = chunk_data.get("stop_reason", None)
-                if stop_reason != None:
-                    is_finished = True
-                    finish_reason = stop_reason
             elif "delta" in chunk_data:
-                if chunk_data["delta"].get("text", None) is not None:
-                    text = chunk_data["delta"]["text"]
-                stop_reason = chunk_data["delta"].get("stop_reason", None)
-                if stop_reason != None:
-                    is_finished = True
-                    finish_reason = stop_reason
+                return self.converse_chunk_parser(chunk_data=chunk_data)
             ######## bedrock.mistral mappings ###############
             elif "outputs" in chunk_data:
                 if (
@@ -1137,11 +1923,11 @@ class AWSEventStreamDecoder:
             is_finished = True
             finish_reason = chunk_data["completionReason"]
         return GenericStreamingChunk(
-            **{
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
+            text=text,
+            is_finished=is_finished,
+            finish_reason=finish_reason,
+            tool_str="",
+            usage=None,
         )
 
     def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
@@ -1178,9 +1964,14 @@ class AWSEventStreamDecoder:
         parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
         if response_dict["status_code"] != 200:
             raise ValueError(f"Bad response code, expected 200: {response_dict}")
-
-        chunk = parsed_response.get("chunk")
-        if not chunk:
-            return None
-
-        return chunk.get("bytes").decode()  # type: ignore[no-any-return]
+        if "chunk" in parsed_response:
+            chunk = parsed_response.get("chunk")
+            if not chunk:
+                return None
+            return chunk.get("bytes").decode()  # type: ignore[no-any-return]
+        else:
+            chunk = response_dict.get("body")
+            if not chunk:
+                return None
+            return chunk.decode()  # type: ignore[no-any-return]
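For reference, a sketch of the Converse stream events the new `converse_chunk_parser` handles; the payloads are illustrative but follow the keys checked in the parser above:

```python
# Text delta event -> GenericStreamingChunk with text="Hello"
text_chunk = {"delta": {"text": "Hello"}}

# Tool-use delta event -> partial JSON arguments surface in `tool_str`
tool_chunk = {"delta": {"toolUse": {"input": '{"location": "Boston'}}}

# Stop and usage events close out the stream
stop_chunk = {"stopReason": "end_turn"}
usage_chunk = {"usage": {"inputTokens": 12, "outputTokens": 34, "totalTokens": 46}}

# Typical use inside the decoder (model id is illustrative):
# decoder = AWSEventStreamDecoder(model="anthropic.claude-3-sonnet-20240229-v1:0")
# decoder.converse_chunk_parser(chunk_data=text_chunk)
```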
@@ -156,12 +156,13 @@ class HTTPHandler:
         self,
         url: str,
         data: Optional[Union[dict, str]] = None,
+        json: Optional[Union[dict, str]] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
         stream: bool = False,
     ):
         req = self.client.build_request(
-            "POST", url, data=data, params=params, headers=headers  # type: ignore
+            "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
         )
         response = self.client.send(req, stream=stream)
         return response
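A small sketch of the new `json=` pass-through on the synchronous handler (the URL is illustrative); `data=` still works for pre-serialized bodies:

```python
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()

# New in this commit: let httpx serialize the body and set the JSON content type
resp = client.post("https://example.com/api", json={"prompt": "hi"})

# Existing behavior: send an already-encoded payload
resp = client.post(
    "https://example.com/api",
    data='{"prompt": "hi"}',
    headers={"Content-Type": "application/json"},
)
```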
@@ -3,14 +3,7 @@ import requests, traceback
 import json, re, xml.etree.ElementTree as ET
 from jinja2 import Template, exceptions, meta, BaseLoader
 from jinja2.sandbox import ImmutableSandboxedEnvironment
-from typing import (
-    Any,
-    List,
-    Mapping,
-    MutableMapping,
-    Optional,
-    Sequence,
-)
+from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
 import litellm
 import litellm.types
 from litellm.types.completion import (
@@ -24,7 +17,7 @@ from litellm.types.completion import (
 import litellm.types.llms
 from litellm.types.llms.anthropic import *
 import uuid
+from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
 import litellm.types.llms.vertex_ai
 
 
@@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
     try:
         from PIL import Image
     except:
-        raise Exception(
-            "gemini image conversion failed please run `pip install Pillow`"
-        )
+        raise Exception("image conversion failed please run `pip install Pillow`")
     from io import BytesIO
 
     try:
@ -1613,6 +1604,380 @@ def azure_text_pt(messages: list):
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
###### AMAZON BEDROCK #######
|
||||||
|
|
||||||
|
from litellm.types.llms.bedrock import (
|
||||||
|
ToolResultContentBlock as BedrockToolResultContentBlock,
|
||||||
|
ToolResultBlock as BedrockToolResultBlock,
|
||||||
|
ToolConfigBlock as BedrockToolConfigBlock,
|
||||||
|
ToolUseBlock as BedrockToolUseBlock,
|
||||||
|
ImageSourceBlock as BedrockImageSourceBlock,
|
||||||
|
ImageBlock as BedrockImageBlock,
|
||||||
|
ContentBlock as BedrockContentBlock,
|
||||||
|
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
|
||||||
|
ToolSpecBlock as BedrockToolSpecBlock,
|
||||||
|
ToolBlock as BedrockToolBlock,
|
||||||
|
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_details(image_url) -> Tuple[str, str]:
|
||||||
|
try:
|
||||||
|
import base64
|
||||||
|
|
||||||
|
# Send a GET request to the image URL
|
||||||
|
response = requests.get(image_url)
|
||||||
|
response.raise_for_status() # Raise an exception for HTTP errors
|
||||||
|
|
||||||
|
# Check the response's content type to ensure it is an image
|
||||||
|
content_type = response.headers.get("content-type")
|
||||||
|
if not content_type or "image" not in content_type:
|
||||||
|
raise ValueError(
|
||||||
|
f"URL does not point to a valid image (content-type: {content_type})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert the image content to base64 bytes
|
||||||
|
base64_bytes = base64.b64encode(response.content).decode("utf-8")
|
||||||
|
|
||||||
|
# Get mime-type
|
||||||
|
mime_type = content_type.split("/")[
|
||||||
|
1
|
||||||
|
] # Extract mime-type from content-type header
|
||||||
|
|
||||||
|
return base64_bytes, mime_type
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
raise Exception(f"Request failed: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
+def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
+    if "base64" in image_url:
+        # Case 1: Images with base64 encoding
+        import base64, re
+
+        # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
+        image_metadata, img_without_base_64 = image_url.split(",")
+
+        # read mime_type from img_without_base_64=data:image/jpeg;base64
+        # Extract MIME type using regular expression
+        mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
+        if mime_type_match:
+            mime_type = mime_type_match.group(1)
+            image_format = mime_type.split("/")[1]
+        else:
+            mime_type = "image/jpeg"
+            image_format = "jpeg"
+        _blob = BedrockImageSourceBlock(bytes=img_without_base_64)
+        supported_image_formats = (
+            litellm.AmazonConverseConfig().get_supported_image_types()
+        )
+        if image_format in supported_image_formats:
+            return BedrockImageBlock(source=_blob, format=image_format)  # type: ignore
+        else:
+            # Handle the case when the image format is not supported
+            raise ValueError(
+                "Unsupported image format: {}. Supported formats: {}".format(
+                    image_format, supported_image_formats
+                )
+            )
+    elif "https:/" in image_url:
+        # Case 2: Images with direct links
+        image_bytes, image_format = get_image_details(image_url)
+        _blob = BedrockImageSourceBlock(bytes=image_bytes)
+        supported_image_formats = (
+            litellm.AmazonConverseConfig().get_supported_image_types()
+        )
+        if image_format in supported_image_formats:
+            return BedrockImageBlock(source=_blob, format=image_format)  # type: ignore
+        else:
+            # Handle the case when the image format is not supported
+            raise ValueError(
+                "Unsupported image format: {}. Supported formats: {}".format(
+                    image_format, supported_image_formats
+                )
+            )
+    else:
+        raise ValueError(
+            "Unsupported image type. Expected either image url or base64 encoded string - \
+                e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
+        )
+
+
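For illustration only (not part of the diff), here is a minimal, self-contained sketch of the data-URL parsing the base64 branch above performs, with the litellm-specific block types left out:

```python
import re

def split_data_url(image_url: str):
    # "data:image/png;base64,<payload>" -> ("image/png", "png", "<payload>")
    image_metadata, payload = image_url.split(",")
    mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
    mime_type = mime_type_match.group(1) if mime_type_match else "image/jpeg"
    return mime_type, mime_type.split("/")[1], payload

print(split_data_url("data:image/png;base64,iVBORw0KGgo="))
# ('image/png', 'png', 'iVBORw0KGgo=')
```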
+def _convert_to_bedrock_tool_call_invoke(
+    tool_calls: list,
+) -> List[BedrockContentBlock]:
+    """
+    OpenAI tool invokes:
+    {
+      "role": "assistant",
+      "content": null,
+      "tool_calls": [
+        {
+          "id": "call_abc123",
+          "type": "function",
+          "function": {
+            "name": "get_current_weather",
+            "arguments": "{\n\"location\": \"Boston, MA\"\n}"
+          }
+        }
+      ]
+    },
+    """
+    """
+    Bedrock tool invokes:
+    [
+        {
+            "role": "assistant",
+            "toolUse": {
+                "input": {"location": "Boston, MA", ..},
+                "name": "get_current_weather",
+                "toolUseId": "call_abc123"
+            }
+        }
+    ]
+    """
+    """
+    - json.loads argument
+    - extract name
+    - extract id
+    """
+
+    try:
+        _parts_list: List[BedrockContentBlock] = []
+        for tool in tool_calls:
+            if "function" in tool:
+                id = tool["id"]
+                name = tool["function"].get("name", "")
+                arguments = tool["function"].get("arguments", "")
+                arguments_dict = json.loads(arguments)
+                bedrock_tool = BedrockToolUseBlock(
+                    input=arguments_dict, name=name, toolUseId=id
+                )
+                bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
+                _parts_list.append(bedrock_content_block)
+        return _parts_list
+    except Exception as e:
+        raise Exception(
+            "Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
+                tool_calls, str(e)
+            )
+        )
+
+
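As a hedged, standalone sketch of the mapping the helper above performs (plain dicts are used here instead of the typed Bedrock blocks, purely for illustration):

```python
import json

def openai_tool_calls_to_bedrock(tool_calls: list) -> list:
    # Mirrors the conversion above: parse the JSON arguments, keep name and id.
    blocks = []
    for tool in tool_calls:
        if "function" in tool:
            blocks.append(
                {
                    "toolUse": {
                        "input": json.loads(tool["function"].get("arguments", "{}")),
                        "name": tool["function"].get("name", ""),
                        "toolUseId": tool["id"],
                    }
                }
            )
    return blocks

print(
    openai_tool_calls_to_bedrock(
        [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"location": "Boston, MA"}',
                },
            }
        ]
    )
)
```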
+def _convert_to_bedrock_tool_call_result(
+    message: dict,
+) -> BedrockMessageBlock:
+    """
+    OpenAI message with a tool result looks like:
+    {
+        "tool_call_id": "tool_1",
+        "role": "tool",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    },
+
+    OpenAI message with a function call result looks like:
+    {
+        "role": "function",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    }
+    """
+    """
+    Bedrock result looks like this:
+    {
+        "role": "user",
+        "content": [
+            {
+                "toolResult": {
+                    "toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
+                    "content": [
+                        {
+                            "json": {
+                                "song": "Elemental Hotel",
+                                "artist": "8 Storey Hike"
+                            }
+                        }
+                    ]
+                }
+            }
+        ]
+    }
+    """
+    """
+    -
+    """
+    content = message.get("content", "")
+    name = message.get("name", "")
+    id = message.get("tool_call_id", str(uuid.uuid4()))
+
+    tool_result_content_block = BedrockToolResultContentBlock(text=content)
+    tool_result = BedrockToolResultBlock(
+        content=[tool_result_content_block],
+        toolUseId=id,
+    )
+    content_block = BedrockContentBlock(toolResult=tool_result)
+
+    return BedrockMessageBlock(role="user", content=[content_block])
+
+
+def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
+    """
+    Converts given messages from OpenAI format to Bedrock format
+
+    - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
+    - Please ensure that function response turn comes immediately after a function call turn
+    """
+
+    contents: List[BedrockMessageBlock] = []
+    msg_i = 0
+    while msg_i < len(messages):
+        user_content: List[BedrockContentBlock] = []
+        init_msg_i = msg_i
+        ## MERGE CONSECUTIVE USER CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "user":
+            if isinstance(messages[msg_i]["content"], list):
+                _parts: List[BedrockContentBlock] = []
+                for element in messages[msg_i]["content"]:
+                    if isinstance(element, dict):
+                        if element["type"] == "text":
+                            _part = BedrockContentBlock(text=element["text"])
+                            _parts.append(_part)
+                        elif element["type"] == "image_url":
+                            image_url = element["image_url"]["url"]
+                            _part = _process_bedrock_converse_image_block(  # type: ignore
+                                image_url=image_url
+                            )
+                            _parts.append(BedrockContentBlock(image=_part))  # type: ignore
+                user_content.extend(_parts)
+            else:
+                _part = BedrockContentBlock(text=messages[msg_i]["content"])
+                user_content.append(_part)
+
+            msg_i += 1
+
+        if user_content:
+            contents.append(BedrockMessageBlock(role="user", content=user_content))
+        assistant_content: List[BedrockContentBlock] = []
+        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+            if isinstance(messages[msg_i]["content"], list):
+                assistants_parts: List[BedrockContentBlock] = []
+                for element in messages[msg_i]["content"]:
+                    if isinstance(element, dict):
+                        if element["type"] == "text":
+                            assistants_part = BedrockContentBlock(text=element["text"])
+                            assistants_parts.append(assistants_part)
+                        elif element["type"] == "image_url":
+                            image_url = element["image_url"]["url"]
+                            assistants_part = _process_bedrock_converse_image_block(  # type: ignore
+                                image_url=image_url
+                            )
+                            assistants_parts.append(
+                                BedrockContentBlock(image=assistants_part)  # type: ignore
+                            )
+                assistant_content.extend(assistants_parts)
+            elif messages[msg_i].get(
+                "tool_calls", []
+            ):  # support assistant tool invoke convertion
+                assistant_content.extend(
+                    _convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
+                )
+            else:
+                assistant_text = (
+                    messages[msg_i].get("content") or ""
+                )  # either string or none
+                if assistant_text:
+                    assistant_content.append(BedrockContentBlock(text=assistant_text))
+
+            msg_i += 1
+
+        if assistant_content:
+            contents.append(
+                BedrockMessageBlock(role="assistant", content=assistant_content)
+            )
+
+        ## APPEND TOOL CALL MESSAGES ##
+        if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+            tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
+            contents.append(tool_call_result)
+            msg_i += 1
+        if msg_i == init_msg_i:  # prevent infinite loops
+            raise Exception(
+                "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+                    messages[msg_i]
+                )
+            )
+
+    return contents
+
+
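For context, a simplified, text-only sketch (not the helper above, which also handles images and tool calls) of the consecutive-role merging that Bedrock's alternating-role requirement forces:

```python
def merge_consecutive_roles(messages: list) -> list:
    # Back-to-back messages with the same role are folded into one turn.
    merged = []
    for m in messages:
        if merged and merged[-1]["role"] == m["role"]:
            merged[-1]["content"].append({"text": m["content"]})
        else:
            merged.append({"role": m["role"], "content": [{"text": m["content"]}]})
    return merged

print(
    merge_consecutive_roles(
        [
            {"role": "user", "content": "Hi"},
            {"role": "user", "content": "What's the weather?"},
            {"role": "assistant", "content": "Let me check."},
        ]
    )
)
# Two turns: one merged 'user' turn with both texts, then the 'assistant' turn.
```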
+def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
+    """
+    OpenAI tools looks like:
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
+        }
+    ]
+    """
+    """
+    Bedrock toolConfig looks like:
+    "tools": [
+        {
+            "toolSpec": {
+                "name": "top_song",
+                "description": "Get the most popular song played on a radio station.",
+                "inputSchema": {
+                    "json": {
+                        "type": "object",
+                        "properties": {
+                            "sign": {
+                                "type": "string",
+                                "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."
+                            }
+                        },
+                        "required": [
+                            "sign"
+                        ]
+                    }
+                }
+            }
+        }
+    ]
+    """
+    tool_block_list: List[BedrockToolBlock] = []
+    for tool in tools:
+        parameters = tool.get("function", {}).get("parameters", None)
+        name = tool.get("function", {}).get("name", "")
+        description = tool.get("function", {}).get("description", "")
+        tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
+        tool_spec = BedrockToolSpecBlock(
+            inputSchema=tool_input_schema, name=name, description=description
+        )
+        tool_block = BedrockToolBlock(toolSpec=tool_spec)
+        tool_block_list.append(tool_block)
+
+    return tool_block_list
+
+
 # Function call template
 def function_call_prompt(messages: list, functions: list):
     function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""
@@ -12,6 +12,7 @@ from litellm.llms.prompt_templates.factory import (
     convert_to_gemini_tool_call_result,
     convert_to_gemini_tool_call_invoke,
 )
+from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type


 class VertexAIError(Exception):
@@ -297,29 +298,31 @@ def _convert_gemini_role(role: str) -> Literal["user", "model"]:

 def _process_gemini_image(image_url: str) -> PartType:
     try:
-        if ".mp4" in image_url and "gs://" in image_url:
-            # Case 1: Videos with Cloud Storage URIs
-            part_mime = "video/mp4"
-            _file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
-            return PartType(file_data=_file_data)
-        elif ".pdf" in image_url and "gs://" in image_url:
-            # Case 2: PDF's with Cloud Storage URIs
-            part_mime = "application/pdf"
-            _file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
-            return PartType(file_data=_file_data)
-        elif "gs://" in image_url:
-            # Case 3: Images with Cloud Storage URIs
-            # The supported MIME types for images include image/png and image/jpeg.
-            part_mime = "image/png" if "png" in image_url else "image/jpeg"
-            _file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
-            return PartType(file_data=_file_data)
+        # GCS URIs
+        if "gs://" in image_url:
+            # Figure out file type
+            extension_with_dot = os.path.splitext(image_url)[-1]  # Ex: ".png"
+            extension = extension_with_dot[1:]  # Ex: "png"
+
+            file_type = get_file_type_from_extension(extension)
+
+            # Validate the file type is supported by Gemini
+            if not is_gemini_1_5_accepted_file_type(file_type):
+                raise Exception(f"File type not supported by gemini - {file_type}")
+
+            mime_type = get_file_mime_type_for_file_type(file_type)
+            file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
+
+            return PartType(file_data=file_data)
+
+        # Direct links
         elif "https:/" in image_url:
-            # Case 4: Images with direct links
             image = _load_image_from_url(image_url)
             _blob = BlobType(data=image.data, mime_type=image._mime_type)
             return PartType(inline_data=_blob)
+
+        # Base64 encoding
         elif "base64" in image_url:
-            # Case 5: Images with base64 encoding
             import base64, re

             # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
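To make the new GCS branch concrete, a runnable sketch of the extension-to-MIME lookup it relies on; the lookup table below is a stand-in for illustration, while the real mapping lives in `litellm.types.files`:

```python
import os

# Hypothetical subset of the real mapping, for illustration only.
EXTENSION_TO_MIME = {
    "png": "image/png",
    "jpg": "image/jpeg",
    "jpeg": "image/jpeg",
    "mp4": "video/mp4",
    "pdf": "application/pdf",
}

def gcs_uri_to_mime(uri: str) -> str:
    extension = os.path.splitext(uri)[-1][1:]  # "gs://bucket/demo.mp4" -> "mp4"
    if extension not in EXTENSION_TO_MIME:
        raise Exception(f"File type not supported by gemini - {extension}")
    return EXTENSION_TO_MIME[extension]

print(gcs_uri_to_mime("gs://my-bucket/videos/demo.mp4"))  # video/mp4
```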
@@ -426,112 +429,6 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
     return contents


-def _gemini_vision_convert_messages(messages: list):
-    """
-    Converts given messages for GPT-4 Vision to Gemini format.
-
-    Args:
-        messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
-            - If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
-            - If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.
-
-    Returns:
-        tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
-
-    Raises:
-        VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
-        Exception: If any other exception occurs during the execution of the function.
-
-    Note:
-        This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub.
-        The supported MIME types for images include 'image/png' and 'image/jpeg'.
-
-    Examples:
-        >>> messages = [
-        ...     {"content": "Hello, world!"},
-        ...     {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": "example.com/image.png"}]},
-        ... ]
-        >>> _gemini_vision_convert_messages(messages)
-        ('Hello, world!This is a text message.', [<Part object>, <Part object>])
-    """
-    try:
-        import vertexai
-    except:
-        raise VertexAIError(
-            status_code=400,
-            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
-        )
-    try:
-        from vertexai.preview.language_models import (
-            ChatModel,
-            CodeChatModel,
-            InputOutputTextPair,
-        )
-        from vertexai.language_models import TextGenerationModel, CodeGenerationModel
-        from vertexai.preview.generative_models import (
-            GenerativeModel,
-            Part,
-            GenerationConfig,
-            Image,
-        )
-
-        # given messages for gpt-4 vision, convert them for gemini
-        # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
-        prompt = ""
-        images = []
-        for message in messages:
-            if isinstance(message["content"], str):
-                prompt += message["content"]
-            elif isinstance(message["content"], list):
-                # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
-                for element in message["content"]:
-                    if isinstance(element, dict):
-                        if element["type"] == "text":
-                            prompt += element["text"]
-                        elif element["type"] == "image_url":
-                            image_url = element["image_url"]["url"]
-                            images.append(image_url)
-        # processing images passed to gemini
-        processed_images = []
-        for img in images:
-            if "gs://" in img:
-                # Case 1: Images with Cloud Storage URIs
-                # The supported MIME types for images include image/png and image/jpeg.
-                part_mime = "image/png" if "png" in img else "image/jpeg"
-                google_clooud_part = Part.from_uri(img, mime_type=part_mime)
-                processed_images.append(google_clooud_part)
-            elif "https:/" in img:
-                # Case 2: Images with direct links
-                image = _load_image_from_url(img)
-                processed_images.append(image)
-            elif ".mp4" in img and "gs://" in img:
-                # Case 3: Videos with Cloud Storage URIs
-                part_mime = "video/mp4"
-                google_clooud_part = Part.from_uri(img, mime_type=part_mime)
-                processed_images.append(google_clooud_part)
-            elif "base64" in img:
-                # Case 4: Images with base64 encoding
-                import base64, re
-
-                # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
-                image_metadata, img_without_base_64 = img.split(",")
-
-                # read mime_type from img_without_base_64=data:image/jpeg;base64
-                # Extract MIME type using regular expression
-                mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
-
-                if mime_type_match:
-                    mime_type = mime_type_match.group(1)
-                else:
-                    mime_type = "image/jpeg"
-                decoded_img = base64.b64decode(img_without_base_64)
-                processed_image = Part.from_data(data=decoded_img, mime_type=mime_type)
-                processed_images.append(processed_image)
-        return prompt, processed_images
-    except Exception as e:
-        raise e
-
-
 def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
     _cache_key = f"{model}-{vertex_project}-{vertex_location}"
     return _cache_key
@@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.predibase import PredibaseChatCompletion
-from .llms.bedrock_httpx import BedrockLLM
+from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
 from .llms.vertex_httpx import VertexLLM
 from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
@@ -122,6 +122,7 @@ huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
+bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 ####### COMPLETION ENDPOINTS ################
@@ -2107,22 +2108,40 @@ def completion(
                 logging_obj=logging,
             )
         else:
-            response = bedrock_chat_completion.completion(
-                model=model,
-                messages=messages,
-                custom_prompt_dict=custom_prompt_dict,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                logging_obj=logging,
-                extra_headers=extra_headers,
-                timeout=timeout,
-                acompletion=acompletion,
-                client=client,
-            )
+            if model.startswith("anthropic"):
+                response = bedrock_converse_chat_completion.completion(
+                    model=model,
+                    messages=messages,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    encoding=encoding,
+                    logging_obj=logging,
+                    extra_headers=extra_headers,
+                    timeout=timeout,
+                    acompletion=acompletion,
+                    client=client,
+                )
+            else:
+                response = bedrock_chat_completion.completion(
+                    model=model,
+                    messages=messages,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    encoding=encoding,
+                    logging_obj=logging,
+                    extra_headers=extra_headers,
+                    timeout=timeout,
+                    acompletion=acompletion,
+                    client=client,
+                )
         if optional_params.get("stream", False):
             ## LOGGING
             logging.post_call(
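With this branch in place, Anthropic model ids on Bedrock take the new Converse path while other Bedrock models keep the existing invoke path. A minimal call that would exercise the new path, assuming AWS credentials are already configured in the environment:

```python
import litellm

# Routed to bedrock_converse_chat_completion by the model.startswith("anthropic") check above.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)
```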
@@ -1,11 +1,20 @@
 from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
 from dataclasses import fields
 import enum
-from typing import Optional, List, Union, Dict, Literal, Any, TypedDict
+from typing import Optional, List, Union, Dict, Literal, Any, TypedDict, TYPE_CHECKING
 from datetime import datetime
 import uuid, json, sys, os
 from litellm.types.router import UpdateRouterConfig
 from litellm.types.utils import ProviderField
+from typing_extensions import Annotated
+
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any


 class LitellmUserRoles(str, enum.Enum):
@@ -1195,6 +1204,7 @@ class UserAPIKeyAuth(
         ]
     ] = None
     allowed_model_region: Optional[Literal["eu"]] = None
+    parent_otel_span: Optional[Span] = None

     @model_validator(mode="before")
     @classmethod
@@ -1207,6 +1217,9 @@ class UserAPIKeyAuth(
             values.update({"api_key": hash_token(values.get("api_key"))})
         return values

+    class Config:
+        arbitrary_types_allowed = True
+

 class LiteLLM_Config(LiteLLMBase):
     param_name: str
@@ -17,10 +17,19 @@ from litellm.proxy._types import (
     LiteLLM_OrganizationTable,
     LitellmUserRoles,
 )
-from typing import Optional, Literal, Union
-from litellm.proxy.utils import PrismaClient
+from typing import Optional, Literal, TYPE_CHECKING, Any
+from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
 from litellm.caching import DualCache
 import litellm
+from litellm.types.services import ServiceLoggerPayload, ServiceTypes
+from datetime import datetime
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+
 all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value

@@ -216,10 +225,13 @@ def get_actual_routes(allowed_routes: list) -> list:
     return actual_routes


+@log_to_opentelemetry
 async def get_end_user_object(
     end_user_id: Optional[str],
     prisma_client: Optional[PrismaClient],
     user_api_key_cache: DualCache,
+    parent_otel_span: Optional[Span] = None,
+    proxy_logging_obj: Optional[ProxyLogging] = None,
 ) -> Optional[LiteLLM_EndUserTable]:
     """
     Returns end user object, if in db.
@@ -279,11 +291,14 @@ async def get_end_user_object(
     return None


+@log_to_opentelemetry
 async def get_user_object(
     user_id: str,
     prisma_client: Optional[PrismaClient],
     user_api_key_cache: DualCache,
     user_id_upsert: bool,
+    parent_otel_span: Optional[Span] = None,
+    proxy_logging_obj: Optional[ProxyLogging] = None,
 ) -> Optional[LiteLLM_UserTable]:
     """
     - Check if user id in proxy User Table
@@ -330,10 +345,13 @@ async def get_user_object(
     )


+@log_to_opentelemetry
 async def get_team_object(
     team_id: str,
     prisma_client: Optional[PrismaClient],
     user_api_key_cache: DualCache,
+    parent_otel_span: Optional[Span] = None,
+    proxy_logging_obj: Optional[ProxyLogging] = None,
 ) -> LiteLLM_TeamTable:
     """
     - Check if team id in proxy Team Table
@@ -372,10 +390,13 @@ async def get_team_object(
     )


+@log_to_opentelemetry
 async def get_org_object(
     org_id: str,
     prisma_client: Optional[PrismaClient],
     user_api_key_cache: DualCache,
+    parent_otel_span: Optional[Span] = None,
+    proxy_logging_obj: Optional[ProxyLogging] = None,
 ):
     """
     - Check if org id in proxy Org Table
litellm/proxy/litellm_pre_call_utils.py (new file, 130 lines)
@@ -0,0 +1,130 @@
+import copy
+from fastapi import Request
+from typing import Any, Dict, Optional, TYPE_CHECKING
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm._logging import verbose_proxy_logger, verbose_logger
+
+if TYPE_CHECKING:
+    from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
+
+    ProxyConfig = _ProxyConfig
+else:
+    ProxyConfig = Any
+
+
+def parse_cache_control(cache_control):
+    cache_dict = {}
+    directives = cache_control.split(", ")
+
+    for directive in directives:
+        if "=" in directive:
+            key, value = directive.split("=")
+            cache_dict[key] = value
+        else:
+            cache_dict[directive] = True
+
+    return cache_dict
+
+
+async def add_litellm_data_to_request(
+    data: dict,
+    request: Request,
+    user_api_key_dict: UserAPIKeyAuth,
+    proxy_config: ProxyConfig,
+    general_settings: Optional[Dict[str, Any]] = None,
+    version: Optional[str] = None,
+):
+    """
+    Adds LiteLLM-specific data to the request.
+
+    Args:
+        data (dict): The data dictionary to be modified.
+        request (Request): The incoming request.
+        user_api_key_dict (UserAPIKeyAuth): The user API key dictionary.
+        general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None.
+        version (Optional[str], optional): Version. Defaults to None.
+
+    Returns:
+        dict: The modified data dictionary.
+
+    """
+    query_params = dict(request.query_params)
+    if "api-version" in query_params:
+        data["api_version"] = query_params["api-version"]
+
+    # Include original request and headers in the data
+    data["proxy_server_request"] = {
+        "url": str(request.url),
+        "method": request.method,
+        "headers": dict(request.headers),
+        "body": copy.copy(data),  # use copy instead of deepcopy
+    }
+
+    ## Cache Controls
+    headers = request.headers
+    verbose_proxy_logger.debug("Request Headers: %s", headers)
+    cache_control_header = headers.get("Cache-Control", None)
+    if cache_control_header:
+        cache_dict = parse_cache_control(cache_control_header)
+        data["ttl"] = cache_dict.get("s-maxage")
+
+    verbose_proxy_logger.debug("receiving data: %s", data)
+    # users can pass in 'user' param to /chat/completions. Don't override it
+    if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+        # if users are using user_api_key_auth, set `user` in `data`
+        data["user"] = user_api_key_dict.user_id
+
+    if "metadata" not in data:
+        data["metadata"] = {}
+    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+    data["metadata"]["user_api_key_alias"] = getattr(
+        user_api_key_dict, "key_alias", None
+    )
+    data["metadata"]["user_api_end_user_max_budget"] = getattr(
+        user_api_key_dict, "end_user_max_budget", None
+    )
+    data["metadata"]["litellm_api_version"] = version
+
+    if general_settings is not None:
+        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
+            "global_max_parallel_requests", None
+        )
+
+    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
+    data["metadata"]["user_api_key_team_id"] = getattr(
+        user_api_key_dict, "team_id", None
+    )
+    data["metadata"]["user_api_key_team_alias"] = getattr(
+        user_api_key_dict, "team_alias", None
+    )
+    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+    _headers = dict(request.headers)
+    _headers.pop(
+        "authorization", None
+    )  # do not store the original `sk-..` api key in the db
+    data["metadata"]["headers"] = _headers
+    data["metadata"]["endpoint"] = str(request.url)
+    # Add the OTEL Parent Trace before sending it LiteLLM
+    data["metadata"]["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span
+
+    ### END-USER SPECIFIC PARAMS ###
+    if user_api_key_dict.allowed_model_region is not None:
+        data["allowed_model_region"] = user_api_key_dict.allowed_model_region
+
+    ### TEAM-SPECIFIC PARAMS ###
+    if user_api_key_dict.team_id is not None:
+        team_config = await proxy_config.load_team_config(
+            team_id=user_api_key_dict.team_id
+        )
+        if len(team_config) == 0:
+            pass
+        else:
+            team_id = team_config.pop("team_id", None)
+            data["metadata"]["team_id"] = team_id
+            data = {
+                **team_config,
+                **data,
+            }  # add the team-specific configs to the completion call
+
+    return data
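For a quick sanity check of the cache-control handling in the new file, `parse_cache_control` can be exercised on its own; `add_litellm_data_to_request` then reads the `s-maxage` directive out of the resulting dict as the cache TTL:

```python
from litellm.proxy.litellm_pre_call_utils import parse_cache_control

print(parse_cache_control("s-maxage=604800, public"))
# {'s-maxage': '604800', 'public': True}
```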
@@ -21,10 +21,14 @@ model_list:

 general_settings:
   master_key: sk-1234
+  alerting: ["slack"]
+
+litellm_settings:
+  callbacks: ["otel"]
+  store_audit_logs: true
+  redact_messages_in_exceptions: True
   enforced_params:
     - user
     - metadata
     - metadata.generation_name
-
-litellm_settings:
-  store_audit_logs: true
File diff suppressed because it is too large
@@ -1,4 +1,4 @@
-from typing import Optional, List, Any, Literal, Union
+from typing import Optional, List, Any, Literal, Union, TYPE_CHECKING
 import os
 import subprocess
 import hashlib
@@ -46,6 +46,15 @@ from email.mime.text import MIMEText
 from email.mime.multipart import MIMEMultipart
 from datetime import datetime, timedelta
 from litellm.integrations.slack_alerting import SlackAlerting
+from typing_extensions import overload
+from functools import wraps
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any


 def print_verbose(print_statement):
@@ -63,6 +72,58 @@ def print_verbose(print_statement):
     print(f"LiteLLM Proxy: {print_statement}")  # noqa


+def safe_deep_copy(data):
+    """
+    Safe Deep Copy
+
+    The LiteLLM Request has some object that can-not be pickled / deep copied
+
+    Use this function to safely deep copy the LiteLLM Request
+    """
+
+    # Step 1: Remove the litellm_parent_otel_span
+    if isinstance(data, dict):
+        # remove litellm_parent_otel_span since this is not picklable
+        if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
+            litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span")
+    new_data = copy.deepcopy(data)
+
+    # Step 2: re-add the litellm_parent_otel_span after doing a deep copy
+    if isinstance(data, dict):
+        if "metadata" in data:
+            data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
+    return new_data
+
+
+def log_to_opentelemetry(func):
+    @wraps(func)
+    async def wrapper(*args, **kwargs):
+        start_time = datetime.now()
+        result = await func(*args, **kwargs)
+        end_time = datetime.now()
+
+        # Log to OTEL only if "parent_otel_span" is in kwargs and is not None
+        if (
+            "parent_otel_span" in kwargs
+            and kwargs["parent_otel_span"] is not None
+            and "proxy_logging_obj" in kwargs
+            and kwargs["proxy_logging_obj"] is not None
+        ):
+            proxy_logging_obj = kwargs["proxy_logging_obj"]
+            await proxy_logging_obj.service_logging_obj.async_service_success_hook(
+                service=ServiceTypes.DB,
+                call_type=func.__name__,
+                parent_otel_span=kwargs["parent_otel_span"],
+                duration=0.0,
+                start_time=start_time,
+                end_time=end_time,
+            )
+        # end of logging to otel
+        return result
+
+    return wrapper
+
+
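A short usage sketch of `safe_deep_copy` as defined above, assuming the proxy extras are installed so `litellm.proxy.utils` imports; `object()` stands in for the non-picklable OTEL span:

```python
from litellm.proxy.utils import safe_deep_copy

request_body = {
    "messages": [{"role": "user", "content": "hi"}],
    "metadata": {"litellm_parent_otel_span": object()},  # stand-in for a real Span
}

copied = safe_deep_copy(request_body)
assert "litellm_parent_otel_span" not in copied["metadata"]    # span dropped from the copy
assert "litellm_parent_otel_span" in request_body["metadata"]  # span restored on the original
```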
 ### LOGGING ###
 class ProxyLogging:
     """
@@ -282,7 +343,7 @@ class ProxyLogging:
         """
         Runs the CustomLogger's async_moderation_hook()
         """
-        new_data = copy.deepcopy(data)
+        new_data = safe_deep_copy(data)
         for callback in litellm.callbacks:
             try:
                 if isinstance(callback, CustomLogger):
@@ -832,6 +893,7 @@ class PrismaClient:
             max_time=10,  # maximum total time to retry for
             on_backoff=on_backoff,  # specifying the function to call on backoff
         )
+    @log_to_opentelemetry
     async def get_data(
         self,
         token: Optional[Union[str, list]] = None,
@@ -858,6 +920,8 @@ class PrismaClient:
         limit: Optional[
             int
         ] = None,  # pagination, number of rows to getch when find_all==True
+        parent_otel_span: Optional[Span] = None,
+        proxy_logging_obj: Optional[ProxyLogging] = None,
     ):
         args_passed_in = locals()
         start_time = time.time()
@@ -2829,6 +2893,10 @@ missing_keys_html_form = """
 """


+def _to_ns(dt):
+    return int(dt.timestamp() * 1e9)
+
+
 def get_error_message_str(e: Exception) -> str:
     error_message = ""
     if isinstance(e, HTTPException):
litellm/py.typed (new file, 2 lines)
@@ -0,0 +1,2 @@
+# Marker file to instruct type checkers to look for inline type annotations in this package.
+# See PEP 561 for more information.
File diff suppressed because it is too large
@@ -243,6 +243,7 @@ def test_completion_bedrock_claude_sts_oidc_auth():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+
 @pytest.mark.skipif(
     os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
     reason="Cannot run without being in CircleCI Runner",
@@ -277,7 +278,15 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-def test_bedrock_claude_3():
+@pytest.mark.parametrize(
+    "image_url",
+    [
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC",
+        "https://avatars.githubusercontent.com/u/29436595?v=",
+    ],
+)
+def test_bedrock_claude_3(image_url):
     try:
         litellm.set_verbose = True
         data = {
@@ -294,7 +303,7 @@ def test_bedrock_claude_3():
                         {
                             "image_url": {
                                 "detail": "high",
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC",
+                                "url": image_url,
                             },
                             "type": "image_url",
                         },
@@ -313,7 +322,6 @@ def test_bedrock_claude_3():
         # Add any assertions here to check the response
         assert len(response.choices) > 0
         assert len(response.choices[0].message.content) > 0
-
     except RateLimitError:
         pass
     except Exception as e:
@@ -552,7 +560,7 @@ def test_bedrock_ptu():
     assert "url" in mock_client_post.call_args.kwargs
     assert (
         mock_client_post.call_args.kwargs["url"]
-        == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
+        == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/converse"
     )
     mock_client_post.assert_called_once()
@@ -300,7 +300,11 @@ def test_completion_claude_3():
         pytest.fail(f"Error occurred: {e}")


-def test_completion_claude_3_function_call():
+@pytest.mark.parametrize(
+    "model",
+    ["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
+)
+def test_completion_claude_3_function_call(model):
     litellm.set_verbose = True
     tools = [
         {
@@ -331,13 +335,14 @@ def test_completion_claude_3_function_call():
     try:
         # test without max tokens
         response = completion(
-            model="anthropic/claude-3-opus-20240229",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice={
                 "type": "function",
                 "function": {"name": "get_current_weather"},
             },
+            drop_params=True,
         )

         # Add any assertions, here to check response args
@@ -364,10 +369,11 @@ def test_completion_claude_3_function_call():
         )
         # In the second response, Claude should deduce answer from tool results
         second_response = completion(
-            model="anthropic/claude-3-opus-20240229",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice="auto",
+            drop_params=True,
         )
         print(second_response)
     except Exception as e:
@@ -1398,7 +1404,6 @@ def test_hf_test_completion_tgi():


 def mock_post(url, data=None, json=None, headers=None):
-
     print(f"url={url}")
     if "text-classification" in url:
         raise Exception("Model not found")
@@ -2241,9 +2246,6 @@ def test_re_use_openaiClient():
     pytest.fail("got Exception", e)


-# test_re_use_openaiClient()
-
-
 def test_completion_azure():
     try:
         print("azure gpt-3.5 test\n\n")
@@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import (
     claude_2_1_pt,
     llama_2_chat_pt,
     prompt_factory,
+    _bedrock_tools_pt,
 )


@@ -128,3 +129,27 @@ def test_anthropic_messages_pt():


 # codellama_prompt_format()
+def test_bedrock_tool_calling_pt():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    converted_tools = _bedrock_tools_pt(tools=tools)
+
+    print(converted_tools)
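Based on the `_bedrock_tools_pt` conversion added earlier in this PR, the test above should print one toolSpec block per OpenAI tool, roughly of this shape (shown as plain dicts for illustration, not the typed blocks):

```python
expected_shape = [
    {
        "toolSpec": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "inputSchema": {
                "json": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                }
            },
        }
    }
]
```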
@@ -210,7 +210,9 @@ def test_chat_completion_exception_any_model(client):
         )
         assert isinstance(openai_exception, openai.BadRequestError)
         _error_message = openai_exception.message
-        assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str(_error_message)
+        assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str(
+            _error_message
+        )

     except Exception as e:
         pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
@@ -238,7 +240,9 @@ def test_embedding_exception_any_model(client):
         print("Exception raised=", openai_exception)
         assert isinstance(openai_exception, openai.BadRequestError)
         _error_message = openai_exception.message
-        assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str(_error_message)
+        assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str(
+            _error_message
+        )

     except Exception as e:
         pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
@ -1284,18 +1284,18 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
     # pytest.fail(f"Error occurred: {e}")


-@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize("sync_mode", [True])  # False
 @pytest.mark.parametrize(
     "model",
     [
-        # "bedrock/cohere.command-r-plus-v1:0",
-        # "anthropic.claude-3-sonnet-20240229-v1:0",
-        # "anthropic.claude-instant-v1",
-        # "bedrock/ai21.j2-mid",
-        # "mistral.mistral-7b-instruct-v0:2",
-        # "bedrock/amazon.titan-tg1-large",
-        # "meta.llama3-8b-instruct-v1:0",
-        "cohere.command-text-v14"
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
+        "mistral.mistral-7b-instruct-v0:2",
+        "bedrock/amazon.titan-tg1-large",
+        "meta.llama3-8b-instruct-v1:0",
+        "cohere.command-text-v14",
     ],
 )
 @pytest.mark.asyncio
281  litellm/types/files.py  (new file)

@ -0,0 +1,281 @@
from enum import Enum
from types import MappingProxyType
from typing import List, Set

"""
Base Enums/Consts
"""


class FileType(Enum):
    AAC = "AAC"
    CSV = "CSV"
    DOC = "DOC"
    DOCX = "DOCX"
    FLAC = "FLAC"
    FLV = "FLV"
    GIF = "GIF"
    GOOGLE_DOC = "GOOGLE_DOC"
    GOOGLE_DRAWINGS = "GOOGLE_DRAWINGS"
    GOOGLE_SHEETS = "GOOGLE_SHEETS"
    GOOGLE_SLIDES = "GOOGLE_SLIDES"
    HEIC = "HEIC"
    HEIF = "HEIF"
    HTML = "HTML"
    JPEG = "JPEG"
    JSON = "JSON"
    M4A = "M4A"
    M4V = "M4V"
    MOV = "MOV"
    MP3 = "MP3"
    MP4 = "MP4"
    MPEG = "MPEG"
    MPEGPS = "MPEGPS"
    MPG = "MPG"
    MPA = "MPA"
    MPGA = "MPGA"
    OGG = "OGG"
    OPUS = "OPUS"
    PDF = "PDF"
    PCM = "PCM"
    PNG = "PNG"
    PPT = "PPT"
    PPTX = "PPTX"
    RTF = "RTF"
    THREE_GPP = "3GPP"
    TXT = "TXT"
    WAV = "WAV"
    WEBM = "WEBM"
    WEBP = "WEBP"
    WMV = "WMV"
    XLS = "XLS"
    XLSX = "XLSX"


FILE_EXTENSIONS: MappingProxyType[FileType, List[str]] = MappingProxyType(
    {
        FileType.AAC: ["aac"],
        FileType.CSV: ["csv"],
        FileType.DOC: ["doc"],
        FileType.DOCX: ["docx"],
        FileType.FLAC: ["flac"],
        FileType.FLV: ["flv"],
        FileType.GIF: ["gif"],
        FileType.GOOGLE_DOC: ["gdoc"],
        FileType.GOOGLE_DRAWINGS: ["gdraw"],
        FileType.GOOGLE_SHEETS: ["gsheet"],
        FileType.GOOGLE_SLIDES: ["gslides"],
        FileType.HEIC: ["heic"],
        FileType.HEIF: ["heif"],
        FileType.HTML: ["html", "htm"],
        FileType.JPEG: ["jpeg", "jpg"],
        FileType.JSON: ["json"],
        FileType.M4A: ["m4a"],
        FileType.M4V: ["m4v"],
        FileType.MOV: ["mov"],
        FileType.MP3: ["mp3"],
        FileType.MP4: ["mp4"],
        FileType.MPEG: ["mpeg"],
        FileType.MPEGPS: ["mpegps"],
        FileType.MPG: ["mpg"],
        FileType.MPA: ["mpa"],
        FileType.MPGA: ["mpga"],
        FileType.OGG: ["ogg"],
        FileType.OPUS: ["opus"],
        FileType.PDF: ["pdf"],
        FileType.PCM: ["pcm"],
        FileType.PNG: ["png"],
        FileType.PPT: ["ppt"],
        FileType.PPTX: ["pptx"],
        FileType.RTF: ["rtf"],
        FileType.THREE_GPP: ["3gpp"],
        FileType.TXT: ["txt"],
        FileType.WAV: ["wav"],
        FileType.WEBM: ["webm"],
        FileType.WEBP: ["webp"],
        FileType.WMV: ["wmv"],
        FileType.XLS: ["xls"],
        FileType.XLSX: ["xlsx"],
    }
)

FILE_MIME_TYPES: MappingProxyType[FileType, str] = MappingProxyType(
    {
        FileType.AAC: "audio/aac",
        FileType.CSV: "text/csv",
        FileType.DOC: "application/msword",
        FileType.DOCX: "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        FileType.FLAC: "audio/flac",
        FileType.FLV: "video/x-flv",
        FileType.GIF: "image/gif",
        FileType.GOOGLE_DOC: "application/vnd.google-apps.document",
        FileType.GOOGLE_DRAWINGS: "application/vnd.google-apps.drawing",
        FileType.GOOGLE_SHEETS: "application/vnd.google-apps.spreadsheet",
        FileType.GOOGLE_SLIDES: "application/vnd.google-apps.presentation",
        FileType.HEIC: "image/heic",
        FileType.HEIF: "image/heif",
        FileType.HTML: "text/html",
        FileType.JPEG: "image/jpeg",
        FileType.JSON: "application/json",
        FileType.M4A: "audio/x-m4a",
        FileType.M4V: "video/x-m4v",
        FileType.MOV: "video/quicktime",
        FileType.MP3: "audio/mpeg",
        FileType.MP4: "video/mp4",
        FileType.MPEG: "video/mpeg",
        FileType.MPEGPS: "video/mpegps",
        FileType.MPG: "video/mpg",
        FileType.MPA: "audio/m4a",
        FileType.MPGA: "audio/mpga",
        FileType.OGG: "audio/ogg",
        FileType.OPUS: "audio/opus",
        FileType.PDF: "application/pdf",
        FileType.PCM: "audio/pcm",
        FileType.PNG: "image/png",
        FileType.PPT: "application/vnd.ms-powerpoint",
        FileType.PPTX: "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        FileType.RTF: "application/rtf",
        FileType.THREE_GPP: "video/3gpp",
        FileType.TXT: "text/plain",
        FileType.WAV: "audio/wav",
        FileType.WEBM: "video/webm",
        FileType.WEBP: "image/webp",
        FileType.WMV: "video/wmv",
        FileType.XLS: "application/vnd.ms-excel",
        FileType.XLSX: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    }
)

"""
Util Functions
"""


def get_file_mime_type_from_extension(extension: str) -> str:
    for file_type, extensions in FILE_EXTENSIONS.items():
        if extension in extensions:
            return FILE_MIME_TYPES[file_type]
    raise ValueError(f"Unknown mime type for extension: {extension}")


def get_file_extension_from_mime_type(mime_type: str) -> str:
    for file_type, mime in FILE_MIME_TYPES.items():
        if mime == mime_type:
            return FILE_EXTENSIONS[file_type][0]
    raise ValueError(f"Unknown extension for mime type: {mime_type}")


def get_file_type_from_extension(extension: str) -> FileType:
    for file_type, extensions in FILE_EXTENSIONS.items():
        if extension in extensions:
            return file_type

    raise ValueError(f"Unknown file type for extension: {extension}")


def get_file_extension_for_file_type(file_type: FileType) -> str:
    return FILE_EXTENSIONS[file_type][0]


def get_file_mime_type_for_file_type(file_type: FileType) -> str:
    return FILE_MIME_TYPES[file_type]


"""
FileType Type Groupings (Videos, Images, etc)
"""

# Images
IMAGE_FILE_TYPES = {
    FileType.PNG,
    FileType.JPEG,
    FileType.GIF,
    FileType.WEBP,
    FileType.HEIC,
    FileType.HEIF,
}


def is_image_file_type(file_type):
    return file_type in IMAGE_FILE_TYPES


# Videos
VIDEO_FILE_TYPES = {
    FileType.MOV,
    FileType.MP4,
    FileType.MPEG,
    FileType.M4V,
    FileType.FLV,
    FileType.MPEGPS,
    FileType.MPG,
    FileType.WEBM,
    FileType.WMV,
    FileType.THREE_GPP,
}


def is_video_file_type(file_type):
    return file_type in VIDEO_FILE_TYPES


# Audio
AUDIO_FILE_TYPES = {
    FileType.AAC,
    FileType.FLAC,
    FileType.MP3,
    FileType.MPA,
    FileType.MPGA,
    FileType.OPUS,
    FileType.PCM,
    FileType.WAV,
}


def is_audio_file_type(file_type):
    return file_type in AUDIO_FILE_TYPES


# Text
TEXT_FILE_TYPES = {FileType.CSV, FileType.HTML, FileType.RTF, FileType.TXT}


def is_text_file_type(file_type):
    return file_type in TEXT_FILE_TYPES


"""
Other FileType Groupings
"""
# Accepted file types for GEMINI 1.5 through Vertex AI
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#gemini-send-multimodal-samples-images-nodejs
GEMINI_1_5_ACCEPTED_FILE_TYPES: Set[FileType] = {
    # Image
    FileType.PNG,
    FileType.JPEG,
    # Audio
    FileType.AAC,
    FileType.FLAC,
    FileType.MP3,
    FileType.MPA,
    FileType.MPGA,
    FileType.OPUS,
    FileType.PCM,
    FileType.WAV,
    # Video
    FileType.FLV,
    FileType.MOV,
    FileType.MPEG,
    FileType.MPEGPS,
    FileType.MPG,
    FileType.MP4,
    FileType.WEBM,
    FileType.WMV,
    FileType.THREE_GPP,
    # PDF
    FileType.PDF,
}


def is_gemini_1_5_accepted_file_type(file_type: FileType) -> bool:
    return file_type in GEMINI_1_5_ACCEPTED_FILE_TYPES
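As a quick orientation (not part of the diff), here is how the helpers in the new `litellm/types/files.py` compose; it only uses the enum members and functions defined above.

```python
# Illustrative sketch only — exercises the helpers added in litellm/types/files.py above.
from litellm.types.files import (
    FileType,
    get_file_mime_type_from_extension,
    get_file_type_from_extension,
    is_audio_file_type,
    is_gemini_1_5_accepted_file_type,
    is_image_file_type,
)

ft = get_file_type_from_extension("png")                 # FileType.PNG
print(get_file_mime_type_from_extension("png"))          # image/png
print(is_image_file_type(ft))                            # True
print(is_audio_file_type(FileType.WAV))                  # True
print(is_gemini_1_5_accepted_file_type(FileType.OGG))    # False - OGG is not in the Gemini 1.5 allow-list
```

The `MappingProxyType` wrappers keep the extension and MIME tables read-only at runtime.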
@ -1,4 +1,4 @@
-from typing import TypedDict, Any, Union, Optional
+from typing import TypedDict, Any, Union, Optional, Literal, List
 import json
 from typing_extensions import (
     Self,
@ -11,10 +11,137 @@ from typing_extensions import (
 )

+
+class SystemContentBlock(TypedDict):
+    text: str
+
+
+class ImageSourceBlock(TypedDict):
+    bytes: Optional[str]  # base 64 encoded string
+
+
+class ImageBlock(TypedDict):
+    format: Literal["png", "jpeg", "gif", "webp"]
+    source: ImageSourceBlock
+
+
+class ToolResultContentBlock(TypedDict, total=False):
+    image: ImageBlock
+    json: dict
+    text: str
+
+
+class ToolResultBlock(TypedDict, total=False):
+    content: Required[List[ToolResultContentBlock]]
+    toolUseId: Required[str]
+    status: Literal["success", "error"]
+
+
+class ToolUseBlock(TypedDict):
+    input: dict
+    name: str
+    toolUseId: str
+
+
+class ContentBlock(TypedDict, total=False):
+    text: str
+    image: ImageBlock
+    toolResult: ToolResultBlock
+    toolUse: ToolUseBlock
+
+
+class MessageBlock(TypedDict):
+    content: List[ContentBlock]
+    role: Literal["user", "assistant"]
+
+
+class ConverseMetricsBlock(TypedDict):
+    latencyMs: float  # time in ms
+
+
+class ConverseResponseOutputBlock(TypedDict):
+    message: Optional[MessageBlock]
+
+
+class ConverseTokenUsageBlock(TypedDict):
+    inputTokens: int
+    outputTokens: int
+    totalTokens: int
+
+
+class ConverseResponseBlock(TypedDict):
+    additionalModelResponseFields: dict
+    metrics: ConverseMetricsBlock
+    output: ConverseResponseOutputBlock
+    stopReason: (
+        str  # end_turn | tool_use | max_tokens | stop_sequence | content_filtered
+    )
+    usage: ConverseTokenUsageBlock
+
+
+class ToolInputSchemaBlock(TypedDict):
+    json: Optional[dict]
+
+
+class ToolSpecBlock(TypedDict, total=False):
+    inputSchema: Required[ToolInputSchemaBlock]
+    name: Required[str]
+    description: str
+
+
+class ToolBlock(TypedDict):
+    toolSpec: Optional[ToolSpecBlock]
+
+
+class SpecificToolChoiceBlock(TypedDict):
+    name: str
+
+
+class ToolChoiceValuesBlock(TypedDict, total=False):
+    any: dict
+    auto: dict
+    tool: SpecificToolChoiceBlock
+
+
+class ToolConfigBlock(TypedDict, total=False):
+    tools: Required[List[ToolBlock]]
+    toolChoice: Union[str, ToolChoiceValuesBlock]
+
+
+class InferenceConfig(TypedDict, total=False):
+    maxTokens: int
+    stopSequences: List[str]
+    temperature: float
+    topP: float
+
+
+class ToolBlockDeltaEvent(TypedDict):
+    input: str
+
+
+class ContentBlockDeltaEvent(TypedDict, total=False):
+    """
+    Either 'text' or 'toolUse' will be specified for Converse API streaming response.
+    """
+
+    text: str
+    toolUse: ToolBlockDeltaEvent
+
+
+class RequestObject(TypedDict, total=False):
+    additionalModelRequestFields: dict
+    additionalModelResponseFieldPaths: List[str]
+    inferenceConfig: InferenceConfig
+    messages: Required[List[MessageBlock]]
+    system: List[SystemContentBlock]
+    toolConfig: ToolConfigBlock
+
+
 class GenericStreamingChunk(TypedDict):
     text: Required[str]
+    tool_str: Required[str]
     is_finished: Required[bool]
     finish_reason: Required[str]
+    usage: Optional[ConverseTokenUsageBlock]


 class Document(TypedDict):
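A minimal sketch of how the new Converse TypedDicts nest (illustrative values; the import path matches the `litellm.types.llms.bedrock` module these blocks are added to, as used elsewhere in this diff):

```python
# Illustrative only: composing a Converse-style request payload from the TypedDicts above.
from typing import List

from litellm.types.llms.bedrock import (
    InferenceConfig,
    MessageBlock,
    RequestObject,
    SystemContentBlock,
)

system: List[SystemContentBlock] = [{"text": "You are a concise assistant."}]
messages: List[MessageBlock] = [
    {"role": "user", "content": [{"text": "What's the weather in SF?"}]}
]
inference: InferenceConfig = {"maxTokens": 256, "temperature": 0.2}

request: RequestObject = {
    "messages": messages,
    "system": system,
    "inferenceConfig": inference,
}
print(request)  # plain dicts at runtime; the TypedDicts only aid static checking
```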
@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False):
     extra_headers: Optional[Dict[str, str]]
     extra_body: Optional[Dict[str, str]]
     timeout: Optional[float]
+
+
+class ChatCompletionToolCallFunctionChunk(TypedDict):
+    name: str
+    arguments: str
+
+
+class ChatCompletionToolCallChunk(TypedDict):
+    id: str
+    type: Literal["function"]
+    function: ChatCompletionToolCallFunctionChunk
+
+
+class ChatCompletionResponseMessage(TypedDict, total=False):
+    content: Optional[str]
+    tool_calls: List[ChatCompletionToolCallChunk]
+    role: Literal["assistant"]
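For reference, a small sketch of the tool-call message shape these new TypedDicts describe (assuming they live in `litellm/types/llms/openai.py` alongside `ListBatchRequest`; the id and argument values are placeholders):

```python
# Illustrative only — building an assistant tool-call message with the TypedDicts above.
from litellm.types.llms.openai import (
    ChatCompletionResponseMessage,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
)

fn: ChatCompletionToolCallFunctionChunk = {
    "name": "get_current_weather",
    "arguments": '{"location": "San Francisco, CA"}',
}
tool_call: ChatCompletionToolCallChunk = {
    "id": "call_abc123",  # hypothetical id
    "type": "function",
    "function": fn,
}
message: ChatCompletionResponseMessage = {
    "role": "assistant",
    "content": None,
    "tool_calls": [tool_call],
}
print(message)
```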
@ -3,7 +3,7 @@ from pydantic import BaseModel, Field
 from typing import Optional


-class ServiceTypes(enum.Enum):
+class ServiceTypes(str, enum.Enum):
     """
     Enum for litellm + litellm-adjacent services (redis/postgres/etc.)
     """
123  litellm/utils.py

@ -239,6 +239,8 @@ def map_finish_reason(
         return "length"
     elif finish_reason == "tool_use":  # anthropic
         return "tool_calls"
+    elif finish_reason == "content_filtered":
+        return "content_filter"
     return finish_reason
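The new branch normalizes Bedrock's `content_filtered` stop reason to the OpenAI-style `content_filter`. A quick illustrative check (assumes `map_finish_reason` stays importable from `litellm.utils`):

```python
# Illustrative check of the finish-reason normalization above.
from litellm.utils import map_finish_reason

assert map_finish_reason("tool_use") == "tool_calls"            # anthropic
assert map_finish_reason("content_filtered") == "content_filter"
assert map_finish_reason("stop") == "stop"                      # unmapped values pass through
```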
@ -5655,19 +5657,29 @@ def get_optional_params(
             optional_params["stream"] = stream
         elif "anthropic" in model:
             _check_valid_arg(supported_params=supported_params)
-            # anthropic params on bedrock
-            # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
-            if model.startswith("anthropic.claude-3"):
-                optional_params = (
-                    litellm.AmazonAnthropicClaude3Config().map_openai_params(
-                        non_default_params=non_default_params,
-                        optional_params=optional_params,
-                    )
-                )
-            else:
-                optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
-                    non_default_params=non_default_params,
-                    optional_params=optional_params,
-                )
+            if "aws_bedrock_client" in passed_params:  # deprecated boto3.invoke route.
+                if model.startswith("anthropic.claude-3"):
+                    optional_params = (
+                        litellm.AmazonAnthropicClaude3Config().map_openai_params(
+                            non_default_params=non_default_params,
+                            optional_params=optional_params,
+                        )
+                    )
+                else:
+                    optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
+                        non_default_params=non_default_params,
+                        optional_params=optional_params,
+                    )
+            else:  # bedrock httpx route
+                optional_params = litellm.AmazonConverseConfig().map_openai_params(
+                    model=model,
+                    non_default_params=non_default_params,
+                    optional_params=optional_params,
+                    drop_params=(
+                        drop_params
+                        if drop_params is not None and isinstance(drop_params, bool)
+                        else False
+                    ),
+                )
         elif "amazon" in model:  # amazon titan llms
             _check_valid_arg(supported_params=supported_params)
@ -6445,20 +6457,7 @@ def get_supported_openai_params(
     - None if unmapped
     """
     if custom_llm_provider == "bedrock":
-        if model.startswith("anthropic.claude-3"):
-            return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
-        elif model.startswith("anthropic"):
-            return litellm.AmazonAnthropicConfig().get_supported_openai_params()
-        elif model.startswith("ai21"):
-            return ["max_tokens", "temperature", "top_p", "stream"]
-        elif model.startswith("amazon"):
-            return ["max_tokens", "temperature", "stop", "top_p", "stream"]
-        elif model.startswith("meta"):
-            return ["max_tokens", "temperature", "top_p", "stream"]
-        elif model.startswith("cohere"):
-            return ["stream", "temperature", "max_tokens"]
-        elif model.startswith("mistral"):
-            return ["max_tokens", "temperature", "stop", "top_p", "stream"]
+        return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "ollama":
         return litellm.OllamaConfig().get_supported_openai_params()
     elif custom_llm_provider == "ollama_chat":
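With this change every Bedrock model resolves its supported OpenAI params through `AmazonConverseConfig` instead of per-model prefix branches. An illustrative call (the model id is only an example):

```python
# Illustrative: supported params for a Bedrock model now come from AmazonConverseConfig.
import litellm

params = litellm.get_supported_openai_params(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    custom_llm_provider="bedrock",
)
print(params)  # e.g. "max_tokens", "temperature", "stream", "tools", ... depending on the model
```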
@ -8558,7 +8557,11 @@ def exception_type(
             extra_information = f"\nModel: {model}"
             if _api_base:
                 extra_information += f"\nAPI Base: `{_api_base}`"
-            if messages and len(messages) > 0:
+            if (
+                messages
+                and len(messages) > 0
+                and litellm.redact_messages_in_exceptions is False
+            ):
                 extra_information += f"\nMessages: `{messages}`"

             if _model_group is not None:
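The extra condition means the raw `messages` are only appended to exception debug info when redaction is off. A minimal sketch of opting in from the SDK side (the model and prompt are placeholders):

```python
# Minimal sketch: enable redaction so request messages are left out of enriched error details.
import litellm

litellm.redact_messages_in_exceptions = True

try:
    litellm.completion(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "some sensitive text"}],
    )
except Exception as e:
    print(e)  # the "Messages: ..." line is omitted from the debug info
```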
@ -9124,7 +9127,7 @@ def exception_type(
                 if "Unable to locate credentials" in error_str:
                     exception_mapping_worked = True
                     raise BadRequestError(
-                        message=f"SagemakerException - {error_str}",
+                        message=f"litellm.BadRequestError: SagemakerException - {error_str}",
                         model=model,
                         llm_provider="sagemaker",
                         response=original_exception.response,
@ -9158,10 +9161,16 @@ def exception_type(
                 ):
                     exception_mapping_worked = True
                     raise BadRequestError(
-                        message=f"VertexAIException BadRequestError - {error_str}",
+                        message=f"litellm.BadRequestError: VertexAIException - {error_str}",
                         model=model,
                         llm_provider="vertex_ai",
-                        response=original_exception.response,
+                        response=httpx.Response(
+                            status_code=429,
+                            request=httpx.Request(
+                                method="POST",
+                                url=" https://cloud.google.com/vertex-ai/",
+                            ),
+                        ),
                         litellm_debug_info=extra_information,
                     )
                 elif (
@ -9169,12 +9178,19 @@ def exception_type(
                     or "Content has no parts." in error_str
                 ):
                     exception_mapping_worked = True
-                    raise APIError(
-                        message=f"VertexAIException APIError - {error_str}",
+                    raise litellm.InternalServerError(
+                        message=f"litellm.InternalServerError: VertexAIException - {error_str}",
                         status_code=500,
                         model=model,
                         llm_provider="vertex_ai",
-                        request=original_exception.request,
+                        request=(
+                            original_exception.request
+                            if hasattr(original_exception, "request")
+                            else httpx.Request(
+                                method="POST",
+                                url=" https://cloud.google.com/vertex-ai/",
+                            )
+                        ),
                         litellm_debug_info=extra_information,
                     )
                 elif "403" in error_str:
@ -9183,7 +9199,13 @@ def exception_type(
                         message=f"VertexAIException BadRequestError - {error_str}",
                         model=model,
                         llm_provider="vertex_ai",
-                        response=original_exception.response,
+                        response=httpx.Response(
+                            status_code=429,
+                            request=httpx.Request(
+                                method="POST",
+                                url=" https://cloud.google.com/vertex-ai/",
+                            ),
+                        ),
                         litellm_debug_info=extra_information,
                     )
                 elif "The response was blocked." in error_str:
@ -9230,12 +9252,18 @@ def exception_type(
                         model=model,
                         llm_provider="vertex_ai",
                         litellm_debug_info=extra_information,
-                        response=original_exception.response,
+                        response=httpx.Response(
+                            status_code=429,
+                            request=httpx.Request(
+                                method="POST",
+                                url=" https://cloud.google.com/vertex-ai/",
+                            ),
+                        ),
                     )
                 if original_exception.status_code == 500:
                     exception_mapping_worked = True
-                    raise APIError(
-                        message=f"VertexAIException APIError - {error_str}",
+                    raise litellm.InternalServerError(
+                        message=f"VertexAIException InternalServerError - {error_str}",
                         status_code=500,
                         model=model,
                         llm_provider="vertex_ai",
@ -11423,12 +11451,27 @@ class CustomStreamWrapper:
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "bedrock":
+                from litellm.types.llms.bedrock import GenericStreamingChunk
+
                 if self.received_finish_reason is not None:
                     raise StopIteration
-                response_obj = self.handle_bedrock_stream(chunk)
+                response_obj: GenericStreamingChunk = chunk
                 completion_obj["content"] = response_obj["text"]
+
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) is True
+                    and response_obj["usage"] is not None
+                ):
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"]["inputTokens"],
+                        completion_tokens=response_obj["usage"]["outputTokens"],
+                        total_tokens=response_obj["usage"]["totalTokens"],
+                    )
             elif self.custom_llm_provider == "sagemaker":
                 print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
                 response_obj = self.handle_sagemaker_stream(chunk)
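With the Converse chunks now carrying a `usage` block, streaming callers can ask for token counts on the final chunk. A hedged usage sketch (the model id and AWS credentials are assumptions, not part of the diff):

```python
# Illustrative: request usage on the final streamed chunk from a Bedrock Converse model.
# Assumes AWS credentials are configured in the environment.
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in response:
    print(chunk)
# the last chunk's usage is built from the Converse inputTokens / outputTokens / totalTokens
```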
@ -11695,7 +11738,7 @@ class CustomStreamWrapper:
                 and hasattr(model_response, "usage")
                 and hasattr(model_response.usage, "prompt_tokens")
             ):
-                if self.sent_first_chunk == False:
+                if self.sent_first_chunk is False:
                     completion_obj["role"] = "assistant"
                     self.sent_first_chunk = True
                 model_response.choices[0].delta = Delta(**completion_obj)
@ -11863,6 +11906,8 @@ class CustomStreamWrapper:

     def __next__(self):
         try:
+            if self.completion_stream is None:
+                self.fetch_sync_stream()
             while True:
                 if (
                     isinstance(self.completion_stream, str)
@ -11937,6 +11982,14 @@ class CustomStreamWrapper:
                 custom_llm_provider=self.custom_llm_provider,
             )

+    def fetch_sync_stream(self):
+        if self.completion_stream is None and self.make_call is not None:
+            # Call make_call to get the completion stream
+            self.completion_stream = self.make_call(client=litellm.module_level_client)
+            self._stream_iter = self.completion_stream.__iter__()
+
+        return self.completion_stream
+
     async def fetch_stream(self):
         if self.completion_stream is None and self.make_call is not None:
             # Call make_call to get the completion stream
@ -5,6 +5,10 @@ description = "Library to easily interface with LLM API providers"
|
||||||
authors = ["BerriAI"]
|
authors = ["BerriAI"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
packages = [
|
||||||
|
{ include = "litellm" },
|
||||||
|
{ include = "litellm/py.typed"},
|
||||||
|
]
|
||||||
|
|
||||||
[tool.poetry.urls]
|
[tool.poetry.urls]
|
||||||
homepage = "https://litellm.ai"
|
homepage = "https://litellm.ai"
|
||||||
|
|
|
@ -1 +1,3 @@
-ignore = ["F403", "F401"]
+ignore = ["F405"]
+extend-select = ["E501"]
+line-length = 120