diff --git a/.gitignore b/.gitignore index 8d99ae8af8..69061d62d3 100644 --- a/.gitignore +++ b/.gitignore @@ -59,3 +59,4 @@ myenv/* litellm/proxy/_experimental/out/404/index.html litellm/proxy/_experimental/out/model_hub/index.html litellm/proxy/_experimental/out/onboarding/index.html +litellm/tests/log.txt diff --git a/docs/my-website/docs/proxy/alerting.md b/docs/my-website/docs/proxy/alerting.md index 3ef676bbd6..402de410c9 100644 --- a/docs/my-website/docs/proxy/alerting.md +++ b/docs/my-website/docs/proxy/alerting.md @@ -62,6 +62,23 @@ curl -X GET 'http://localhost:4000/health/services?service=slack' \ -H 'Authorization: Bearer sk-1234' ``` +## Advanced - Redacting Messages from Alerts + +By default alerts show the `messages/input` passed to the LLM. If you want to redact this from slack alerting set the following setting on your config + + +```shell +general_settings: + alerting: ["slack"] + alert_types: ["spend_reports"] + +litellm_settings: + redact_messages_in_exceptions: True +``` + + + + ## Advanced - Opting into specific alert types Set `alert_types` if you want to Opt into only specific alert types diff --git a/litellm/__init__.py b/litellm/__init__.py index 219af0ea18..b6e6d97dc8 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.* ### INIT VARIABLES ### import threading, requests, os from typing import Callable, List, Optional, Dict, Union, Any, Literal -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.caching import Cache from litellm._logging import ( set_verbose, @@ -60,6 +60,7 @@ _async_failure_callback: List[Callable] = ( pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] turn_off_message_logging: Optional[bool] = False +redact_messages_in_exceptions: Optional[bool] = False store_audit_logs = False # Enterprise feature, allow users to see audit logs ## end of callbacks ############# @@ -233,6 +234,7 @@ max_end_user_budget: Optional[float] = None #### RELIABILITY #### request_timeout: float = 6000 module_level_aclient = AsyncHTTPHandler(timeout=request_timeout) +module_level_client = HTTPHandler(timeout=request_timeout) num_retries: Optional[int] = None # per model endpoint default_fallbacks: Optional[List] = None fallbacks: Optional[List] = None @@ -766,7 +768,7 @@ from .llms.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig from .llms.maritalk import MaritTalkConfig -from .llms.bedrock_httpx import AmazonCohereChatConfig +from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig from .llms.bedrock import ( AmazonTitanConfig, AmazonAI21Config, diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index dc6f35642b..fd14b3cdeb 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -1,10 +1,18 @@ -import litellm, traceback +from datetime import datetime +import litellm from litellm.proxy._types import UserAPIKeyAuth from .types.services import ServiceTypes, ServiceLoggerPayload from .integrations.prometheus_services import PrometheusServicesLogger from .integrations.custom_logger import CustomLogger from datetime import timedelta -from typing import Union +from typing import Union, Optional, TYPE_CHECKING, Any + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any class ServiceLogging(CustomLogger): @@ -40,7 +48,13 @@ class ServiceLogging(CustomLogger): self.mock_testing_sync_failure_hook += 1 async def async_service_success_hook( - self, service: ServiceTypes, duration: float, call_type: str + self, + service: ServiceTypes, + call_type: str, + duration: float, + parent_otel_span: Optional[Span] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, ): """ - For counting if the redis, postgres call is successful @@ -61,6 +75,16 @@ class ServiceLogging(CustomLogger): payload=payload ) + from litellm.proxy.proxy_server import open_telemetry_logger + + if parent_otel_span is not None and open_telemetry_logger is not None: + await open_telemetry_logger.async_service_success_hook( + payload=payload, + parent_otel_span=parent_otel_span, + start_time=start_time, + end_time=end_time, + ) + async def async_service_failure_hook( self, service: ServiceTypes, diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 00bc014c40..089b67368a 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -1,9 +1,21 @@ import os -from typing import Optional from dataclasses import dataclass +from datetime import datetime from litellm.integrations.custom_logger import CustomLogger from litellm._logging import verbose_logger +from litellm.types.services import ServiceLoggerPayload +from typing import Union, Optional, TYPE_CHECKING, Any + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth + + Span = _Span + UserAPIKeyAuth = _UserAPIKeyAuth +else: + Span = Any + UserAPIKeyAuth = Any LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm") @@ -77,6 +89,56 @@ class OpenTelemetry(CustomLogger): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): self._handle_failure(kwargs, response_obj, start_time, end_time) + async def async_service_success_hook( + self, + payload: ServiceLoggerPayload, + parent_otel_span: Optional[Span] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + ): + from opentelemetry import trace + from datetime import datetime + from opentelemetry.trace import Status, StatusCode + + if parent_otel_span is not None: + _span_name = payload.service + service_logging_span = self.tracer.start_span( + name=_span_name, + context=trace.set_span_in_context(parent_otel_span), + start_time=self._to_ns(start_time), + ) + service_logging_span.set_attribute(key="call_type", value=payload.call_type) + service_logging_span.set_attribute( + key="service", value=payload.service.value + ) + service_logging_span.set_status(Status(StatusCode.OK)) + service_logging_span.end(end_time=self._to_ns(end_time)) + + async def async_post_call_failure_hook( + self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth + ): + from opentelemetry.trace import Status, StatusCode + from opentelemetry import trace + + parent_otel_span = user_api_key_dict.parent_otel_span + if parent_otel_span is not None: + parent_otel_span.set_status(Status(StatusCode.ERROR)) + _span_name = "Failed Proxy Server Request" + + # Exception Logging Child Span + exception_logging_span = self.tracer.start_span( + name=_span_name, + context=trace.set_span_in_context(parent_otel_span), + ) + exception_logging_span.set_attribute( + key="exception", value=str(original_exception) + ) + exception_logging_span.set_status(Status(StatusCode.ERROR)) + exception_logging_span.end(end_time=self._to_ns(datetime.now())) + + # End Parent OTEL Sspan + parent_otel_span.end(end_time=self._to_ns(datetime.now())) + def _handle_sucess(self, kwargs, response_obj, start_time, end_time): from opentelemetry.trace import Status, StatusCode @@ -85,15 +147,18 @@ class OpenTelemetry(CustomLogger): kwargs, self.config, ) + _parent_context, parent_otel_span = self._get_span_context(kwargs) span = self.tracer.start_span( name=self._get_span_name(kwargs), start_time=self._to_ns(start_time), - context=self._get_span_context(kwargs), + context=_parent_context, ) span.set_status(Status(StatusCode.OK)) self.set_attributes(span, kwargs, response_obj) span.end(end_time=self._to_ns(end_time)) + if parent_otel_span is not None: + parent_otel_span.end(end_time=self._to_ns(datetime.now())) def _handle_failure(self, kwargs, response_obj, start_time, end_time): from opentelemetry.trace import Status, StatusCode @@ -122,17 +187,28 @@ class OpenTelemetry(CustomLogger): from opentelemetry.trace.propagation.tracecontext import ( TraceContextTextMapPropagator, ) + from opentelemetry import trace litellm_params = kwargs.get("litellm_params", {}) or {} proxy_server_request = litellm_params.get("proxy_server_request", {}) or {} headers = proxy_server_request.get("headers", {}) or {} traceparent = headers.get("traceparent", None) + _metadata = litellm_params.get("metadata", {}) + parent_otel_span = _metadata.get("litellm_parent_otel_span", None) + + """ + Two way to use parents in opentelemetry + - using the traceparent header + - using the parent_otel_span in the [metadata][parent_otel_span] + """ + if parent_otel_span is not None: + return trace.set_span_in_context(parent_otel_span), parent_otel_span if traceparent is None: - return None + return None, None else: carrier = {"traceparent": traceparent} - return TraceContextTextMapPropagator().extract(carrier=carrier) + return TraceContextTextMapPropagator().extract(carrier=carrier), None def _get_span_processor(self): from opentelemetry.sdk.trace.export import ( diff --git a/litellm/integrations/slack_alerting.py b/litellm/integrations/slack_alerting.py index c98d60f1fd..21415fb6d6 100644 --- a/litellm/integrations/slack_alerting.py +++ b/litellm/integrations/slack_alerting.py @@ -326,8 +326,8 @@ class SlackAlerting(CustomLogger): end_time=end_time, ) ) - if litellm.turn_off_message_logging: - messages = "Message not logged. `litellm.turn_off_message_logging=True`." + if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions: + messages = "Message not logged. litellm.redact_messages_in_exceptions=True" request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`" slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`" if time_difference_float > self.alerting_threshold: @@ -567,9 +567,12 @@ class SlackAlerting(CustomLogger): except: messages = "" - if litellm.turn_off_message_logging: + if ( + litellm.turn_off_message_logging + or litellm.redact_messages_in_exceptions + ): messages = ( - "Message not logged. `litellm.turn_off_message_logging=True`." + "Message not logged. litellm.redact_messages_in_exceptions=True" ) request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`" else: diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index dbd7e7c695..afc2657610 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -38,6 +38,8 @@ from .prompt_templates.factory import ( extract_between_tags, parse_xml_params, contains_tag, + _bedrock_converse_messages_pt, + _bedrock_tools_pt, ) from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from .base import BaseLLM @@ -45,6 +47,11 @@ import httpx # type: ignore from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator from litellm.types.llms.bedrock import * import urllib.parse +from litellm.types.llms.openai import ( + ChatCompletionResponseMessage, + ChatCompletionToolCallChunk, + ChatCompletionToolCallFunctionChunk, +) class AmazonCohereChatConfig: @@ -118,6 +125,8 @@ class AmazonCohereChatConfig: "presence_penalty", "seed", "stop", + "tools", + "tool_choice", ] def map_openai_params( @@ -169,7 +178,38 @@ async def make_call( logging_obj.post_call( input=messages, api_key="", - original_response=completion_stream, # Pass the completion stream for logging + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream + + +def make_sync_call( + client: Optional[HTTPHandler], + api_base: str, + headers: dict, + data: str, + model: str, + messages: list, + logging_obj, +): + if client is None: + client = HTTPHandler() # Create a new client if none provided + + response = client.post(api_base, headers=headers, data=data, stream=True) + + if response.status_code != 200: + raise BedrockError(status_code=response.status_code, message=response.read()) + + decoder = AWSEventStreamDecoder(model=model) + completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) + + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received", additional_args={"complete_input_dict": data}, ) @@ -1000,12 +1040,12 @@ class BedrockLLM(BaseLLM): if isinstance(timeout, float) or isinstance(timeout, int): timeout = httpx.Timeout(timeout) _params["timeout"] = timeout - self.client = AsyncHTTPHandler(**_params) # type: ignore + client = AsyncHTTPHandler(**_params) # type: ignore else: - self.client = client # type: ignore + client = client # type: ignore try: - response = await self.client.post(api_base, headers=headers, data=data) # type: ignore + response = await client.post(api_base, headers=headers, data=data) # type: ignore response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code @@ -1069,6 +1109,738 @@ class BedrockLLM(BaseLLM): return super().embedding(*args, **kwargs) +class AmazonConverseConfig: + """ + Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features + """ + + maxTokens: Optional[int] + stopSequences: Optional[List[str]] + temperature: Optional[int] + topP: Optional[int] + + def __init__( + self, + maxTokens: Optional[int] = None, + stopSequences: Optional[List[str]] = None, + temperature: Optional[int] = None, + topP: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self, model: str) -> List[str]: + supported_params = [ + "max_tokens", + "stream", + "stream_options", + "stop", + "temperature", + "top_p", + "extra_headers", + ] + + if ( + model.startswith("anthropic") + or model.startswith("mistral") + or model.startswith("cohere") + ): + supported_params.append("tools") + + if model.startswith("anthropic") or model.startswith("mistral"): + # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + supported_params.append("tool_choice") + + return supported_params + + def map_tool_choice_values( + self, model: str, tool_choice: Union[str, dict], drop_params: bool + ) -> Optional[ToolChoiceValuesBlock]: + if tool_choice == "none": + if litellm.drop_params is True or drop_params is True: + return None + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + elif tool_choice == "required": + return ToolChoiceValuesBlock(any={}) + elif tool_choice == "auto": + return ToolChoiceValuesBlock(auto={}) + elif isinstance(tool_choice, dict): + # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + specific_tool = SpecificToolChoiceBlock( + name=tool_choice.get("function", {}).get("name", "") + ) + return ToolChoiceValuesBlock(tool=specific_tool) + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + + def get_supported_image_types(self) -> List[str]: + return ["png", "jpeg", "gif", "webp"] + + def map_openai_params( + self, + model: str, + non_default_params: dict, + optional_params: dict, + drop_params: bool, + ) -> dict: + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["maxTokens"] = value + if param == "stream": + optional_params["stream"] = value + if param == "stop": + if isinstance(value, str): + value = [value] + optional_params["stop_sequences"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["topP"] = value + if param == "tools": + optional_params["tools"] = value + if param == "tool_choice": + _tool_choice_value = self.map_tool_choice_values( + model=model, tool_choice=value, drop_params=drop_params # type: ignore + ) + if _tool_choice_value is not None: + optional_params["tool_choice"] = _tool_choice_value + return optional_params + + +class BedrockConverseLLM(BaseLLM): + def __init__(self) -> None: + super().__init__() + + def process_response( + self, + model: str, + response: Union[requests.Response, httpx.Response], + model_response: ModelResponse, + stream: bool, + logging_obj: Logging, + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, + print_verbose, + encoding, + ) -> Union[ModelResponse, CustomStreamWrapper]: + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + + ## RESPONSE OBJECT + try: + completion_response = ConverseResponseBlock(**response.json()) # type: ignore + except Exception as e: + raise BedrockError( + message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( + response.text, str(e) + ), + status_code=422, + ) + + """ + Bedrock Response Object has optional message block + + completion_response["output"].get("message", None) + + A message block looks like this (Example 1): + "output": { + "message": { + "role": "assistant", + "content": [ + { + "text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?" + } + ] + } + }, + (Example 2): + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA", + "name": "top_song", + "input": { + "sign": "WZPZ" + } + } + } + ] + } + } + + """ + message: Optional[MessageBlock] = completion_response["output"]["message"] + chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} + content_str = "" + tools: List[ChatCompletionToolCallChunk] = [] + if message is not None: + for content in message["content"]: + """ + - Content is either a tool response or text + """ + if "text" in content: + content_str += content["text"] + if "toolUse" in content: + _function_chunk = ChatCompletionToolCallFunctionChunk( + name=content["toolUse"]["name"], + arguments=json.dumps(content["toolUse"]["input"]), + ) + _tool_response_chunk = ChatCompletionToolCallChunk( + id=content["toolUse"]["toolUseId"], + type="function", + function=_function_chunk, + ) + tools.append(_tool_response_chunk) + chat_completion_message["content"] = content_str + chat_completion_message["tool_calls"] = tools + + ## CALCULATING USAGE - bedrock returns usage in the headers + input_tokens = completion_response["usage"]["inputTokens"] + output_tokens = completion_response["usage"]["outputTokens"] + total_tokens = completion_response["usage"]["totalTokens"] + + model_response.choices = [ + litellm.Choices( + finish_reason=map_finish_reason(completion_response["stopReason"]), + index=0, + message=litellm.Message(**chat_completion_message), + ) + ] + model_response["created"] = int(time.time()) + model_response["model"] = model + usage = Usage( + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=total_tokens, + ) + setattr(model_response, "usage", usage) + + return model_response + + def encode_model_id(self, model_id: str) -> str: + """ + Double encode the model ID to ensure it matches the expected double-encoded format. + Args: + model_id (str): The model ID to encode. + Returns: + str: The double-encoded model ID. + """ + return urllib.parse.quote(model_id, safe="") + + def get_credentials( + self, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_session_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + aws_role_name: Optional[str] = None, + aws_web_identity_token: Optional[str] = None, + ): + """ + Return a boto3.Credentials object + """ + import boto3 + + ## CHECK IS 'os.environ/' passed in + params_to_check: List[Optional[str]] = [ + aws_access_key_id, + aws_secret_access_key, + aws_region_name, + aws_session_name, + aws_profile_name, + aws_role_name, + aws_web_identity_token, + ] + + # Iterate over parameters and update if needed + for i, param in enumerate(params_to_check): + if param and param.startswith("os.environ/"): + _v = get_secret(param) + if _v is not None and isinstance(_v, str): + params_to_check[i] = _v + # Assign updated values back to parameters + ( + aws_access_key_id, + aws_secret_access_key, + aws_region_name, + aws_session_name, + aws_profile_name, + aws_role_name, + aws_web_identity_token, + ) = params_to_check + + ### CHECK STS ### + if ( + aws_web_identity_token is not None + and aws_role_name is not None + and aws_session_name is not None + ): + oidc_token = get_secret(aws_web_identity_token) + + if oidc_token is None: + raise BedrockError( + message="OIDC token could not be retrieved from secret manager.", + status_code=401, + ) + + sts_client = boto3.client("sts") + + # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html + sts_response = sts_client.assume_role_with_web_identity( + RoleArn=aws_role_name, + RoleSessionName=aws_session_name, + WebIdentityToken=oidc_token, + DurationSeconds=3600, + ) + + session = boto3.Session( + aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], + aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], + aws_session_token=sts_response["Credentials"]["SessionToken"], + region_name=aws_region_name, + ) + + return session.get_credentials() + elif aws_role_name is not None and aws_session_name is not None: + sts_client = boto3.client( + "sts", + aws_access_key_id=aws_access_key_id, # [OPTIONAL] + aws_secret_access_key=aws_secret_access_key, # [OPTIONAL] + ) + + sts_response = sts_client.assume_role( + RoleArn=aws_role_name, RoleSessionName=aws_session_name + ) + + # Extract the credentials from the response and convert to Session Credentials + sts_credentials = sts_response["Credentials"] + from botocore.credentials import Credentials + + credentials = Credentials( + access_key=sts_credentials["AccessKeyId"], + secret_key=sts_credentials["SecretAccessKey"], + token=sts_credentials["SessionToken"], + ) + return credentials + elif aws_profile_name is not None: ### CHECK SESSION ### + # uses auth values from AWS profile usually stored in ~/.aws/credentials + client = boto3.Session(profile_name=aws_profile_name) + + return client.get_credentials() + else: + session = boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=aws_region_name, + ) + + return session.get_credentials() + + async def async_streaming( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> CustomStreamWrapper: + streaming_response = CustomStreamWrapper( + completion_stream=None, + make_call=partial( + make_call, + client=client, + api_base=api_base, + headers=headers, + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + ), + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + return streaming_response + + async def async_completion( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> Union[ModelResponse, CustomStreamWrapper]: + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = AsyncHTTPHandler(**_params) # type: ignore + else: + client = client # type: ignore + + try: + response = await client.post(api_base, headers=headers, data=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException as e: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + return self.process_response( + model=model, + response=response, + model_response=model_response, + stream=stream if isinstance(stream, bool) else False, + logging_obj=logging_obj, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) + + def completion( + self, + model: str, + messages: list, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: dict, + acompletion: bool, + timeout: Optional[Union[float, httpx.Timeout]], + litellm_params=None, + logger_fn=None, + extra_headers: Optional[dict] = None, + client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, + ): + try: + import boto3 + + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + + ## SETUP ## + stream = optional_params.pop("stream", None) + modelId = optional_params.pop("model_id", None) + if modelId is not None: + modelId = self.encode_model_id(model_id=modelId) + else: + modelId = model + + provider = model.split(".")[0] + + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_profile_name = optional_params.pop("aws_profile_name", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + aws_web_identity_token=aws_web_identity_token, + ) + + ### SET RUNTIME ENDPOINT ### + endpoint_url = "" + env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") + if aws_bedrock_runtime_endpoint is not None and isinstance( + aws_bedrock_runtime_endpoint, str + ): + endpoint_url = aws_bedrock_runtime_endpoint + elif env_aws_bedrock_runtime_endpoint and isinstance( + env_aws_bedrock_runtime_endpoint, str + ): + endpoint_url = env_aws_bedrock_runtime_endpoint + else: + endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + + if (stream is not None and stream is True) and provider != "ai21": + endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" + else: + endpoint_url = f"{endpoint_url}/model/{modelId}/converse" + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + + # Separate system prompt from rest of message + system_prompt_indices = [] + system_content_blocks: List[SystemContentBlock] = [] + for idx, message in enumerate(messages): + if message["role"] == "system": + _system_content_block = SystemContentBlock(text=message["content"]) + system_content_blocks.append(_system_content_block) + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) + + inference_params = copy.deepcopy(optional_params) + additional_request_keys = [] + additional_request_params = {} + supported_converse_params = AmazonConverseConfig.__annotations__.keys() + supported_tool_call_params = ["tools", "tool_choice"] + ## TRANSFORMATION ## + # send all model-specific params in 'additional_request_params' + for k, v in inference_params.items(): + if ( + k not in supported_converse_params + and k not in supported_tool_call_params + ): + additional_request_params[k] = v + additional_request_keys.append(k) + for key in additional_request_keys: + inference_params.pop(key, None) + + bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt( + messages=messages + ) + bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( + inference_params.pop("tools", []) + ) + bedrock_tool_config: Optional[ToolConfigBlock] = None + if len(bedrock_tools) > 0: + tool_choice_values: ToolChoiceValuesBlock = inference_params.pop( + "tool_choice", None + ) + bedrock_tool_config = ToolConfigBlock( + tools=bedrock_tools, + ) + if tool_choice_values is not None: + bedrock_tool_config["toolChoice"] = tool_choice_values + + _data: RequestObject = { + "messages": bedrock_messages, + "additionalModelRequestFields": additional_request_params, + "system": system_content_blocks, + "inferenceConfig": InferenceConfig(**inference_params), + } + if bedrock_tool_config is not None: + _data["toolConfig"] = bedrock_tool_config + data = json.dumps(_data) + ## COMPLETION CALL + + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=endpoint_url, data=data, headers=headers + ) + sigv4.add_auth(request) + prepped = request.prepare() + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": prepped.url, + "headers": prepped.headers, + }, + ) + + ### ROUTING (ASYNC, STREAMING, SYNC) + if acompletion: + if isinstance(client, HTTPHandler): + client = None + if stream is True and provider != "ai21": + return self.async_streaming( + model=model, + messages=messages, + data=data, + api_base=prepped.url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=True, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore + ### ASYNC COMPLETION + return self.async_completion( + model=model, + messages=messages, + data=data, + api_base=prepped.url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, # type: ignore + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore + + if (stream is not None and stream is True) and provider != "ai21": + + streaming_response = CustomStreamWrapper( + completion_stream=None, + make_call=partial( + make_sync_call, + client=None, + api_base=prepped.url, + headers=prepped.headers, # type: ignore + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + ), + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + + return streaming_response + ### COMPLETION + + if client is None or isinstance(client, AsyncHTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = HTTPHandler(**_params) # type: ignore + else: + client = client + try: + response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=response.text) + except httpx.TimeoutException: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + return self.process_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + optional_params=optional_params, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + encoding=encoding, + ) + + def get_response_stream_shape(): from botocore.model import ServiceModel from botocore.loaders import Loader @@ -1086,6 +1858,31 @@ class AWSEventStreamDecoder: self.model = model self.parser = EventStreamJSONParser() + def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: + text = "" + tool_str = "" + is_finished = False + finish_reason = "" + usage: Optional[ConverseTokenUsageBlock] = None + if "delta" in chunk_data: + delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"]) + if "text" in delta_obj: + text = delta_obj["text"] + elif "toolUse" in delta_obj: + tool_str = delta_obj["toolUse"]["input"] + elif "stopReason" in chunk_data: + finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop")) + elif "usage" in chunk_data: + usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore + response = GenericStreamingChunk( + text=text, + tool_str=tool_str, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + ) + return response + def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: text = "" is_finished = False @@ -1098,19 +1895,8 @@ class AWSEventStreamDecoder: is_finished = True finish_reason = "stop" ######## bedrock.anthropic mappings ############### - elif "completion" in chunk_data: # not claude-3 - text = chunk_data["completion"] # bedrock.anthropic - stop_reason = chunk_data.get("stop_reason", None) - if stop_reason != None: - is_finished = True - finish_reason = stop_reason elif "delta" in chunk_data: - if chunk_data["delta"].get("text", None) is not None: - text = chunk_data["delta"]["text"] - stop_reason = chunk_data["delta"].get("stop_reason", None) - if stop_reason != None: - is_finished = True - finish_reason = stop_reason + return self.converse_chunk_parser(chunk_data=chunk_data) ######## bedrock.mistral mappings ############### elif "outputs" in chunk_data: if ( @@ -1137,11 +1923,11 @@ class AWSEventStreamDecoder: is_finished = True finish_reason = chunk_data["completionReason"] return GenericStreamingChunk( - **{ - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + text=text, + is_finished=is_finished, + finish_reason=finish_reason, + tool_str="", + usage=None, ) def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]: @@ -1178,9 +1964,14 @@ class AWSEventStreamDecoder: parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) if response_dict["status_code"] != 200: raise ValueError(f"Bad response code, expected 200: {response_dict}") + if "chunk" in parsed_response: + chunk = parsed_response.get("chunk") + if not chunk: + return None + return chunk.get("bytes").decode() # type: ignore[no-any-return] + else: + chunk = response_dict.get("body") + if not chunk: + return None - chunk = parsed_response.get("chunk") - if not chunk: - return None - - return chunk.get("bytes").decode() # type: ignore[no-any-return] + return chunk.decode() # type: ignore[no-any-return] diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index b91aaee2ae..5ec9c79bb2 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -156,12 +156,13 @@ class HTTPHandler: self, url: str, data: Optional[Union[dict, str]] = None, + json: Optional[Union[dict, str]] = None, params: Optional[dict] = None, headers: Optional[dict] = None, stream: bool = False, ): req = self.client.build_request( - "POST", url, data=data, params=params, headers=headers # type: ignore + "POST", url, data=data, json=json, params=params, headers=headers # type: ignore ) response = self.client.send(req, stream=stream) return response diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 10f3f16ed4..6bf03b52d4 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -3,14 +3,7 @@ import requests, traceback import json, re, xml.etree.ElementTree as ET from jinja2 import Template, exceptions, meta, BaseLoader from jinja2.sandbox import ImmutableSandboxedEnvironment -from typing import ( - Any, - List, - Mapping, - MutableMapping, - Optional, - Sequence, -) +from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple import litellm import litellm.types from litellm.types.completion import ( @@ -24,7 +17,7 @@ from litellm.types.completion import ( import litellm.types.llms from litellm.types.llms.anthropic import * import uuid - +from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock import litellm.types.llms.vertex_ai @@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url): try: from PIL import Image except: - raise Exception( - "gemini image conversion failed please run `pip install Pillow`" - ) + raise Exception("image conversion failed please run `pip install Pillow`") from io import BytesIO try: @@ -1613,6 +1604,380 @@ def azure_text_pt(messages: list): return prompt +###### AMAZON BEDROCK ####### + +from litellm.types.llms.bedrock import ( + ToolResultContentBlock as BedrockToolResultContentBlock, + ToolResultBlock as BedrockToolResultBlock, + ToolConfigBlock as BedrockToolConfigBlock, + ToolUseBlock as BedrockToolUseBlock, + ImageSourceBlock as BedrockImageSourceBlock, + ImageBlock as BedrockImageBlock, + ContentBlock as BedrockContentBlock, + ToolInputSchemaBlock as BedrockToolInputSchemaBlock, + ToolSpecBlock as BedrockToolSpecBlock, + ToolBlock as BedrockToolBlock, + ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock, +) + + +def get_image_details(image_url) -> Tuple[str, str]: + try: + import base64 + + # Send a GET request to the image URL + response = requests.get(image_url) + response.raise_for_status() # Raise an exception for HTTP errors + + # Check the response's content type to ensure it is an image + content_type = response.headers.get("content-type") + if not content_type or "image" not in content_type: + raise ValueError( + f"URL does not point to a valid image (content-type: {content_type})" + ) + + # Convert the image content to base64 bytes + base64_bytes = base64.b64encode(response.content).decode("utf-8") + + # Get mime-type + mime_type = content_type.split("/")[ + 1 + ] # Extract mime-type from content-type header + + return base64_bytes, mime_type + + except requests.RequestException as e: + raise Exception(f"Request failed: {e}") + except Exception as e: + raise e + + +def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock: + if "base64" in image_url: + # Case 1: Images with base64 encoding + import base64, re + + # base 64 is passed as data:image/jpeg;base64, + image_metadata, img_without_base_64 = image_url.split(",") + + # read mime_type from img_without_base_64=data:image/jpeg;base64 + # Extract MIME type using regular expression + mime_type_match = re.match(r"data:(.*?);base64", image_metadata) + if mime_type_match: + mime_type = mime_type_match.group(1) + image_format = mime_type.split("/")[1] + else: + mime_type = "image/jpeg" + image_format = "jpeg" + _blob = BedrockImageSourceBlock(bytes=img_without_base_64) + supported_image_formats = ( + litellm.AmazonConverseConfig().get_supported_image_types() + ) + if image_format in supported_image_formats: + return BedrockImageBlock(source=_blob, format=image_format) # type: ignore + else: + # Handle the case when the image format is not supported + raise ValueError( + "Unsupported image format: {}. Supported formats: {}".format( + image_format, supported_image_formats + ) + ) + elif "https:/" in image_url: + # Case 2: Images with direct links + image_bytes, image_format = get_image_details(image_url) + _blob = BedrockImageSourceBlock(bytes=image_bytes) + supported_image_formats = ( + litellm.AmazonConverseConfig().get_supported_image_types() + ) + if image_format in supported_image_formats: + return BedrockImageBlock(source=_blob, format=image_format) # type: ignore + else: + # Handle the case when the image format is not supported + raise ValueError( + "Unsupported image format: {}. Supported formats: {}".format( + image_format, supported_image_formats + ) + ) + else: + raise ValueError( + "Unsupported image type. Expected either image url or base64 encoded string - \ + e.g. 'data:image/jpeg;base64,'" + ) + + +def _convert_to_bedrock_tool_call_invoke( + tool_calls: list, +) -> List[BedrockContentBlock]: + """ + OpenAI tool invokes: + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\n\"location\": \"Boston, MA\"\n}" + } + } + ] + }, + """ + """ + Bedrock tool invokes: + [ + { + "role": "assistant", + "toolUse": { + "input": {"location": "Boston, MA", ..}, + "name": "get_current_weather", + "toolUseId": "call_abc123" + } + } + ] + """ + """ + - json.loads argument + - extract name + - extract id + """ + + try: + _parts_list: List[BedrockContentBlock] = [] + for tool in tool_calls: + if "function" in tool: + id = tool["id"] + name = tool["function"].get("name", "") + arguments = tool["function"].get("arguments", "") + arguments_dict = json.loads(arguments) + bedrock_tool = BedrockToolUseBlock( + input=arguments_dict, name=name, toolUseId=id + ) + bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool) + _parts_list.append(bedrock_content_block) + return _parts_list + except Exception as e: + raise Exception( + "Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format( + tool_calls, str(e) + ) + ) + + +def _convert_to_bedrock_tool_call_result( + message: dict, +) -> BedrockMessageBlock: + """ + OpenAI message with a tool result looks like: + { + "tool_call_id": "tool_1", + "role": "tool", + "name": "get_current_weather", + "content": "function result goes here", + }, + + OpenAI message with a function call result looks like: + { + "role": "function", + "name": "get_current_weather", + "content": "function result goes here", + } + """ + """ + Bedrock result looks like this: + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q", + "content": [ + { + "json": { + "song": "Elemental Hotel", + "artist": "8 Storey Hike" + } + } + ] + } + } + ] + } + """ + """ + - + """ + content = message.get("content", "") + name = message.get("name", "") + id = message.get("tool_call_id", str(uuid.uuid4())) + + tool_result_content_block = BedrockToolResultContentBlock(text=content) + tool_result = BedrockToolResultBlock( + content=[tool_result_content_block], + toolUseId=id, + ) + content_block = BedrockContentBlock(toolResult=tool_result) + + return BedrockMessageBlock(role="user", content=[content_block]) + + +def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]: + """ + Converts given messages from OpenAI format to Bedrock format + + - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles) + - Please ensure that function response turn comes immediately after a function call turn + """ + + contents: List[BedrockMessageBlock] = [] + msg_i = 0 + while msg_i < len(messages): + user_content: List[BedrockContentBlock] = [] + init_msg_i = msg_i + ## MERGE CONSECUTIVE USER CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] == "user": + if isinstance(messages[msg_i]["content"], list): + _parts: List[BedrockContentBlock] = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text": + _part = BedrockContentBlock(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + _part = _process_bedrock_converse_image_block( # type: ignore + image_url=image_url + ) + _parts.append(BedrockContentBlock(image=_part)) # type: ignore + user_content.extend(_parts) + else: + _part = BedrockContentBlock(text=messages[msg_i]["content"]) + user_content.append(_part) + + msg_i += 1 + + if user_content: + contents.append(BedrockMessageBlock(role="user", content=user_content)) + assistant_content: List[BedrockContentBlock] = [] + ## MERGE CONSECUTIVE ASSISTANT CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": + if isinstance(messages[msg_i]["content"], list): + assistants_parts: List[BedrockContentBlock] = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text": + assistants_part = BedrockContentBlock(text=element["text"]) + assistants_parts.append(assistants_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + assistants_part = _process_bedrock_converse_image_block( # type: ignore + image_url=image_url + ) + assistants_parts.append( + BedrockContentBlock(image=assistants_part) # type: ignore + ) + assistant_content.extend(assistants_parts) + elif messages[msg_i].get( + "tool_calls", [] + ): # support assistant tool invoke convertion + assistant_content.extend( + _convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"]) + ) + else: + assistant_text = ( + messages[msg_i].get("content") or "" + ) # either string or none + if assistant_text: + assistant_content.append(BedrockContentBlock(text=assistant_text)) + + msg_i += 1 + + if assistant_content: + contents.append( + BedrockMessageBlock(role="assistant", content=assistant_content) + ) + + ## APPEND TOOL CALL MESSAGES ## + if msg_i < len(messages) and messages[msg_i]["role"] == "tool": + tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i]) + contents.append(tool_call_result) + msg_i += 1 + if msg_i == init_msg_i: # prevent infinite loops + raise Exception( + "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format( + messages[msg_i] + ) + ) + + return contents + + +def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]: + """ + OpenAI tools looks like: + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ] + """ + """ + Bedrock toolConfig looks like: + "tools": [ + { + "toolSpec": { + "name": "top_song", + "description": "Get the most popular song played on a radio station.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "sign": { + "type": "string", + "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP." + } + }, + "required": [ + "sign" + ] + } + } + } + } + ] + """ + tool_block_list: List[BedrockToolBlock] = [] + for tool in tools: + parameters = tool.get("function", {}).get("parameters", None) + name = tool.get("function", {}).get("name", "") + description = tool.get("function", {}).get("description", "") + tool_input_schema = BedrockToolInputSchemaBlock(json=parameters) + tool_spec = BedrockToolSpecBlock( + inputSchema=tool_input_schema, name=name, description=description + ) + tool_block = BedrockToolBlock(toolSpec=tool_spec) + tool_block_list.append(tool_block) + + return tool_block_list + + # Function call template def function_call_prompt(messages: list, functions: list): function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:""" diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index aabe084b8e..bd9cfaa8d6 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -12,6 +12,7 @@ from litellm.llms.prompt_templates.factory import ( convert_to_gemini_tool_call_result, convert_to_gemini_tool_call_invoke, ) +from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type class VertexAIError(Exception): @@ -297,29 +298,31 @@ def _convert_gemini_role(role: str) -> Literal["user", "model"]: def _process_gemini_image(image_url: str) -> PartType: try: - if ".mp4" in image_url and "gs://" in image_url: - # Case 1: Videos with Cloud Storage URIs - part_mime = "video/mp4" - _file_data = FileDataType(mime_type=part_mime, file_uri=image_url) - return PartType(file_data=_file_data) - elif ".pdf" in image_url and "gs://" in image_url: - # Case 2: PDF's with Cloud Storage URIs - part_mime = "application/pdf" - _file_data = FileDataType(mime_type=part_mime, file_uri=image_url) - return PartType(file_data=_file_data) - elif "gs://" in image_url: - # Case 3: Images with Cloud Storage URIs - # The supported MIME types for images include image/png and image/jpeg. - part_mime = "image/png" if "png" in image_url else "image/jpeg" - _file_data = FileDataType(mime_type=part_mime, file_uri=image_url) - return PartType(file_data=_file_data) + # GCS URIs + if "gs://" in image_url: + # Figure out file type + extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png" + extension = extension_with_dot[1:] # Ex: "png" + + file_type = get_file_type_from_extension(extension) + + # Validate the file type is supported by Gemini + if not is_gemini_1_5_accepted_file_type(file_type): + raise Exception(f"File type not supported by gemini - {file_type}") + + mime_type = get_file_mime_type_for_file_type(file_type) + file_data = FileDataType(mime_type=mime_type, file_uri=image_url) + + return PartType(file_data=file_data) + + # Direct links elif "https:/" in image_url: - # Case 4: Images with direct links image = _load_image_from_url(image_url) _blob = BlobType(data=image.data, mime_type=image._mime_type) return PartType(inline_data=_blob) + + # Base64 encoding elif "base64" in image_url: - # Case 5: Images with base64 encoding import base64, re # base 64 is passed as data:image/jpeg;base64, @@ -426,112 +429,6 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]: return contents -def _gemini_vision_convert_messages(messages: list): - """ - Converts given messages for GPT-4 Vision to Gemini format. - - Args: - messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type: - - If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt. - - If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images. - - Returns: - tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images). - - Raises: - VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed. - Exception: If any other exception occurs during the execution of the function. - - Note: - This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub. - The supported MIME types for images include 'image/png' and 'image/jpeg'. - - Examples: - >>> messages = [ - ... {"content": "Hello, world!"}, - ... {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": "example.com/image.png"}]}, - ... ] - >>> _gemini_vision_convert_messages(messages) - ('Hello, world!This is a text message.', [, ]) - """ - try: - import vertexai - except: - raise VertexAIError( - status_code=400, - message="vertexai import failed please run `pip install google-cloud-aiplatform`", - ) - try: - from vertexai.preview.language_models import ( - ChatModel, - CodeChatModel, - InputOutputTextPair, - ) - from vertexai.language_models import TextGenerationModel, CodeGenerationModel - from vertexai.preview.generative_models import ( - GenerativeModel, - Part, - GenerationConfig, - Image, - ) - - # given messages for gpt-4 vision, convert them for gemini - # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb - prompt = "" - images = [] - for message in messages: - if isinstance(message["content"], str): - prompt += message["content"] - elif isinstance(message["content"], list): - # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models - for element in message["content"]: - if isinstance(element, dict): - if element["type"] == "text": - prompt += element["text"] - elif element["type"] == "image_url": - image_url = element["image_url"]["url"] - images.append(image_url) - # processing images passed to gemini - processed_images = [] - for img in images: - if "gs://" in img: - # Case 1: Images with Cloud Storage URIs - # The supported MIME types for images include image/png and image/jpeg. - part_mime = "image/png" if "png" in img else "image/jpeg" - google_clooud_part = Part.from_uri(img, mime_type=part_mime) - processed_images.append(google_clooud_part) - elif "https:/" in img: - # Case 2: Images with direct links - image = _load_image_from_url(img) - processed_images.append(image) - elif ".mp4" in img and "gs://" in img: - # Case 3: Videos with Cloud Storage URIs - part_mime = "video/mp4" - google_clooud_part = Part.from_uri(img, mime_type=part_mime) - processed_images.append(google_clooud_part) - elif "base64" in img: - # Case 4: Images with base64 encoding - import base64, re - - # base 64 is passed as data:image/jpeg;base64, - image_metadata, img_without_base_64 = img.split(",") - - # read mime_type from img_without_base_64=data:image/jpeg;base64 - # Extract MIME type using regular expression - mime_type_match = re.match(r"data:(.*?);base64", image_metadata) - - if mime_type_match: - mime_type = mime_type_match.group(1) - else: - mime_type = "image/jpeg" - decoded_img = base64.b64decode(img_without_base_64) - processed_image = Part.from_data(data=decoded_img, mime_type=mime_type) - processed_images.append(processed_image) - return prompt, processed_images - except Exception as e: - raise e - - def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str): _cache_key = f"{model}-{vertex_project}-{vertex_location}" return _cache_key diff --git a/litellm/main.py b/litellm/main.py index 5ea6957c02..b62cb78cfd 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface from .llms.predibase import PredibaseChatCompletion -from .llms.bedrock_httpx import BedrockLLM +from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM from .llms.vertex_httpx import VertexLLM from .llms.triton import TritonChatCompletion from .llms.prompt_templates.factory import ( @@ -122,6 +122,7 @@ huggingface = Huggingface() predibase_chat_completions = PredibaseChatCompletion() triton_chat_completions = TritonChatCompletion() bedrock_chat_completion = BedrockLLM() +bedrock_converse_chat_completion = BedrockConverseLLM() vertex_chat_completion = VertexLLM() ####### COMPLETION ENDPOINTS ################ @@ -2107,22 +2108,40 @@ def completion( logging_obj=logging, ) else: - response = bedrock_chat_completion.completion( - model=model, - messages=messages, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - extra_headers=extra_headers, - timeout=timeout, - acompletion=acompletion, - client=client, - ) + if model.startswith("anthropic"): + response = bedrock_converse_chat_completion.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + acompletion=acompletion, + client=client, + ) + else: + response = bedrock_chat_completion.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + acompletion=acompletion, + client=client, + ) if optional_params.get("stream", False): ## LOGGING logging.post_call( diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index e6c6e6fad2..6409773e38 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1,11 +1,20 @@ from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict from dataclasses import fields import enum -from typing import Optional, List, Union, Dict, Literal, Any, TypedDict +from typing import Optional, List, Union, Dict, Literal, Any, TypedDict, TYPE_CHECKING from datetime import datetime import uuid, json, sys, os from litellm.types.router import UpdateRouterConfig from litellm.types.utils import ProviderField +from typing_extensions import Annotated + + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any class LitellmUserRoles(str, enum.Enum): @@ -1195,6 +1204,7 @@ class UserAPIKeyAuth( ] ] = None allowed_model_region: Optional[Literal["eu"]] = None + parent_otel_span: Optional[Span] = None @model_validator(mode="before") @classmethod @@ -1207,6 +1217,9 @@ class UserAPIKeyAuth( values.update({"api_key": hash_token(values.get("api_key"))}) return values + class Config: + arbitrary_types_allowed = True + class LiteLLM_Config(LiteLLMBase): param_name: str diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 969ef8831d..17f0822e61 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -17,10 +17,19 @@ from litellm.proxy._types import ( LiteLLM_OrganizationTable, LitellmUserRoles, ) -from typing import Optional, Literal, Union -from litellm.proxy.utils import PrismaClient +from typing import Optional, Literal, TYPE_CHECKING, Any +from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry from litellm.caching import DualCache import litellm +from litellm.types.services import ServiceLoggerPayload, ServiceTypes +from datetime import datetime + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value @@ -216,10 +225,13 @@ def get_actual_routes(allowed_routes: list) -> list: return actual_routes +@log_to_opentelemetry async def get_end_user_object( end_user_id: Optional[str], prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, ) -> Optional[LiteLLM_EndUserTable]: """ Returns end user object, if in db. @@ -279,11 +291,14 @@ async def get_end_user_object( return None +@log_to_opentelemetry async def get_user_object( user_id: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, user_id_upsert: bool, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, ) -> Optional[LiteLLM_UserTable]: """ - Check if user id in proxy User Table @@ -330,10 +345,13 @@ async def get_user_object( ) +@log_to_opentelemetry async def get_team_object( team_id: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, ) -> LiteLLM_TeamTable: """ - Check if team id in proxy Team Table @@ -372,10 +390,13 @@ async def get_team_object( ) +@log_to_opentelemetry async def get_org_object( org_id: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, ): """ - Check if org id in proxy Org Table diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py new file mode 100644 index 0000000000..945799b4cf --- /dev/null +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -0,0 +1,130 @@ +import copy +from fastapi import Request +from typing import Any, Dict, Optional, TYPE_CHECKING +from litellm.proxy._types import UserAPIKeyAuth +from litellm._logging import verbose_proxy_logger, verbose_logger + +if TYPE_CHECKING: + from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig + + ProxyConfig = _ProxyConfig +else: + ProxyConfig = Any + + +def parse_cache_control(cache_control): + cache_dict = {} + directives = cache_control.split(", ") + + for directive in directives: + if "=" in directive: + key, value = directive.split("=") + cache_dict[key] = value + else: + cache_dict[directive] = True + + return cache_dict + + +async def add_litellm_data_to_request( + data: dict, + request: Request, + user_api_key_dict: UserAPIKeyAuth, + proxy_config: ProxyConfig, + general_settings: Optional[Dict[str, Any]] = None, + version: Optional[str] = None, +): + """ + Adds LiteLLM-specific data to the request. + + Args: + data (dict): The data dictionary to be modified. + request (Request): The incoming request. + user_api_key_dict (UserAPIKeyAuth): The user API key dictionary. + general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None. + version (Optional[str], optional): Version. Defaults to None. + + Returns: + dict: The modified data dictionary. + + """ + query_params = dict(request.query_params) + if "api-version" in query_params: + data["api_version"] = query_params["api-version"] + + # Include original request and headers in the data + data["proxy_server_request"] = { + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + ## Cache Controls + headers = request.headers + verbose_proxy_logger.debug("Request Headers: %s", headers) + cache_control_header = headers.get("Cache-Control", None) + if cache_control_header: + cache_dict = parse_cache_control(cache_control_header) + data["ttl"] = cache_dict.get("s-maxage") + + verbose_proxy_logger.debug("receiving data: %s", data) + # users can pass in 'user' param to /chat/completions. Don't override it + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + # if users are using user_api_key_auth, set `user` in `data` + data["user"] = user_api_key_dict.user_id + + if "metadata" not in data: + data["metadata"] = {} + data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["metadata"]["user_api_end_user_max_budget"] = getattr( + user_api_key_dict, "end_user_max_budget", None + ) + data["metadata"]["litellm_api_version"] = version + + if general_settings is not None: + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) + + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id + data["metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["metadata"]["headers"] = _headers + data["metadata"]["endpoint"] = str(request.url) + # Add the OTEL Parent Trace before sending it LiteLLM + data["metadata"]["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span + + ### END-USER SPECIFIC PARAMS ### + if user_api_key_dict.allowed_model_region is not None: + data["allowed_model_region"] = user_api_key_dict.allowed_model_region + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + return data diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 2f2d8545dd..df6dfd1394 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -21,10 +21,14 @@ model_list: general_settings: master_key: sk-1234 + alerting: ["slack"] + +litellm_settings: + callbacks: ["otel"] + store_audit_logs: true + redact_messages_in_exceptions: True enforced_params: - user - metadata - metadata.generation_name -litellm_settings: - store_audit_logs: true \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0d8983db43..564f886f19 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2,12 +2,19 @@ import sys, os, platform, time, copy, re, asyncio, inspect import threading, ast import shutil, random, traceback, requests from datetime import datetime, timedelta, timezone -from typing import Optional, List, Callable, get_args, Set +from typing import Optional, List, Callable, get_args, Set, Any, TYPE_CHECKING import secrets, subprocess import hashlib, uuid import warnings import importlib +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + def showwarning(message, category, filename, lineno, file=None, line=None): traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n" @@ -82,6 +89,7 @@ import litellm from litellm.types.llms.openai import ( HttpxBinaryResponseContent, ) +from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request from litellm.proxy.utils import ( PrismaClient, DBClient, @@ -103,6 +111,7 @@ from litellm.proxy.utils import ( update_spend, encrypt_value, decrypt_value, + _to_ns, get_error_message_str, ) from litellm import ( @@ -405,6 +414,7 @@ disable_spend_logs = False jwt_handler = JWTHandler() prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None store_model_in_db: bool = False +open_telemetry_logger = None ### INITIALIZE GLOBAL LOGGING OBJECT ### proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) ### REDIS QUEUE ### @@ -500,12 +510,17 @@ async def check_request_disconnection(request: Request, llm_api_call_task): async def user_api_key_auth( request: Request, api_key: str = fastapi.Security(api_key_header) ) -> UserAPIKeyAuth: - global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client, general_settings + global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client, general_settings, proxy_logging_obj try: if isinstance(api_key, str): passed_in_key = api_key api_key = _get_bearer_token(api_key=api_key) - + parent_otel_span: Optional[Span] = None + if open_telemetry_logger is not None: + parent_otel_span = open_telemetry_logger.tracer.start_span( + name="Received Proxy Server Request", + start_time=_to_ns(datetime.now()), + ) ### USER-DEFINED AUTH FUNCTION ### if user_custom_auth is not None: response = await user_custom_auth(request=request, api_key=api_key) @@ -548,7 +563,10 @@ async def user_api_key_auth( litellm_proxy_roles=jwt_handler.litellm_jwtauth, ) if is_allowed: - return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + return UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, + ) else: allowed_routes = ( jwt_handler.litellm_jwtauth.admin_allowed_routes @@ -587,6 +605,8 @@ async def user_api_key_auth( team_id=team_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) # [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable` @@ -598,6 +618,8 @@ async def user_api_key_auth( org_id=org_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) # [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable` user_object = None @@ -611,6 +633,8 @@ async def user_api_key_auth( prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, user_id_upsert=jwt_handler.is_upsert_user_id(), + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable` @@ -624,6 +648,8 @@ async def user_api_key_auth( end_user_id=end_user_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) global_proxy_spend = None @@ -657,7 +683,6 @@ async def user_api_key_auth( user_info=user_info, ) ) - # get the request body request_data = await _read_request_body(request=request) @@ -686,15 +711,21 @@ async def user_api_key_auth( user_role=LitellmUserRoles.INTERNAL_USER, user_id=user_id, org_id=org_id, + parent_otel_span=parent_otel_span, ) #### ELSE #### if master_key is None: if isinstance(api_key, str): return UserAPIKeyAuth( - api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN + api_key=api_key, + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, ) else: - return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + return UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, + ) elif api_key is None: # only require api key if master key is set raise Exception("No api key passed in.") elif api_key == "": @@ -722,6 +753,8 @@ async def user_api_key_auth( end_user_id=request_data["user"], prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) if _end_user_object is not None: end_user_params["allowed_model_region"] = ( @@ -770,6 +803,7 @@ async def user_api_key_auth( valid_token.allowed_model_region = end_user_params.get( "allowed_model_region" ) + valid_token.parent_otel_span = parent_otel_span return valid_token @@ -796,6 +830,7 @@ async def user_api_key_auth( api_key=master_key, user_role=LitellmUserRoles.PROXY_ADMIN, user_id=litellm_proxy_admin_name, + parent_otel_span=parent_otel_span, **end_user_params, ) await user_api_key_cache.async_set_cache( @@ -834,7 +869,10 @@ async def user_api_key_auth( verbose_proxy_logger.debug("api key: %s", api_key) if prisma_client is not None: _valid_token: Optional[BaseModel] = await prisma_client.get_data( - token=api_key, table_name="combined_view" + token=api_key, + table_name="combined_view", + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) if _valid_token is not None: valid_token = UserAPIKeyAuth( @@ -1438,6 +1476,7 @@ async def user_api_key_auth( return UserAPIKeyAuth( api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, **valid_token_dict, ) elif ( @@ -1445,7 +1484,10 @@ async def user_api_key_auth( and route in LiteLLMRoutes.sso_only_routes.value ): return UserAPIKeyAuth( - api_key=api_key, user_role="app_owner", **valid_token_dict + api_key=api_key, + user_role="app_owner", + parent_otel_span=parent_otel_span, + **valid_token_dict, ) else: raise Exception( @@ -1461,18 +1503,21 @@ async def user_api_key_auth( return UserAPIKeyAuth( api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, **valid_token_dict, ) elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value: return UserAPIKeyAuth( api_key=api_key, user_role=LitellmUserRoles.INTERNAL_USER, + parent_otel_span=parent_otel_span, **valid_token_dict, ) else: return UserAPIKeyAuth( api_key=api_key, user_role=LitellmUserRoles.INTERNAL_USER, + parent_otel_span=parent_otel_span, **valid_token_dict, ) else: @@ -2318,7 +2363,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger # Load existing config config = await self.get_config(config_file_path=config_file_path) @@ -2442,7 +2487,9 @@ class ProxyConfig: OpenTelemetry, ) - imported_list.append(OpenTelemetry()) + open_telemetry_logger = OpenTelemetry() + + imported_list.append(open_telemetry_logger) elif isinstance(callback, str) and callback == "presidio": from litellm.proxy.hooks.presidio_pii_masking import ( _OPTIONAL_PresidioPIIMasking, @@ -3821,20 +3868,6 @@ def get_litellm_model_info(model: dict = {}): return {} -def parse_cache_control(cache_control): - cache_dict = {} - directives = cache_control.split(", ") - - for directive in directives: - if "=" in directive: - key, value = directive.split("=") - cache_dict[key] = value - else: - cache_dict[directive] = True - - return cache_dict - - def on_backoff(details): # The 'tries' key in the details dictionary contains the number of completed tries verbose_proxy_logger.debug("Backing off... this was attempt # %s", details["tries"]) @@ -4156,28 +4189,15 @@ async def chat_completion( except: data = json.loads(body_str) - # Azure OpenAI only: check if user passed api-version - query_params = dict(request.query_params) - if "api-version" in query_params: - data["api_version"] = query_params["api-version"] + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) - # Include original request and headers in the data - data["proxy_server_request"] = { - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - ## Cache Controls - headers = request.headers - verbose_proxy_logger.debug("Request Headers: %s", headers) - cache_control_header = headers.get("Cache-Control", None) - if cache_control_header: - cache_dict = parse_cache_control(cache_control_header) - data["ttl"] = cache_dict.get("s-maxage") - - verbose_proxy_logger.debug("receiving data: %s", data) data["model"] = ( general_settings.get("completion_model", None) # server default or user_model # model name passed via cli args @@ -4185,63 +4205,6 @@ async def chat_completion( or data["model"] # default passed in http request ) - # users can pass in 'user' param to /chat/completions. Don't override it - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - # if users are using user_api_key_auth, set `user` in `data` - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["user_api_end_user_max_budget"] = getattr( - user_api_key_dict, "end_user_max_budget", None - ) - data["metadata"]["litellm_api_version"] = version - - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - _is_valid_team_configs( - team_id=team_id, team_config=team_config, request_data=data - ) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call - - ### END-USER SPECIFIC PARAMS ### - if user_api_key_dict.allowed_model_region is not None: - data["allowed_model_region"] = user_api_key_dict.allowed_model_region - global user_temperature, user_request_timeout, user_max_tokens, user_api_base # override with user settings, these are params passed via cli if user_temperature: @@ -4500,7 +4463,6 @@ async def completion( except: data = json.loads(body_str) - data["user"] = data.get("user", user_api_key_dict.user_id) data["model"] = ( general_settings.get("completion_model", None) # server default or user_model # model name passed via cli args @@ -4509,30 +4471,15 @@ async def completion( ) if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["endpoint"] = str(request.url) # override with user settings, these are params passed via cli if user_temperature: @@ -4729,15 +4676,14 @@ async def embeddings( data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) data["model"] = ( general_settings.get("embedding_model", None) # server default @@ -4747,45 +4693,6 @@ async def embeddings( ) if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call ### MODEL ALIAS MAPPING ### # check if model name in model alias map @@ -4945,15 +4852,14 @@ async def image_generation( data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) data["model"] = ( general_settings.get("image_generation_model", None) # server default @@ -4963,46 +4869,6 @@ async def image_generation( if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call - ### MODEL ALIAS MAPPING ### # check if model name in model alias map # get the actual model name @@ -5132,12 +4998,14 @@ async def audio_speech( data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) if data.get("user", None) is None and user_api_key_dict.user_id is not None: data["user"] = user_api_key_dict.user_id @@ -5145,46 +5013,6 @@ async def audio_speech( if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call - router_model_names = llm_router.model_names if llm_router is not None else [] ### CALL HOOKS ### - modify incoming data / reject request before calling the model @@ -5302,12 +5130,14 @@ async def audio_transcriptions( data = {key: value for key, value in form_data.items() if key != "file"} # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) if data.get("user", None) is None and user_api_key_dict.user_id is not None: data["user"] = user_api_key_dict.user_id @@ -5320,47 +5150,6 @@ async def audio_transcriptions( if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - data["metadata"]["file_name"] = file.filename - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call - router_model_names = llm_router.model_names if llm_router is not None else [] assert ( @@ -5516,55 +5305,14 @@ async def get_assistants( body = await request.body() # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - data["metadata"]["litellm_api_version"] = version - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -5649,55 +5397,14 @@ async def create_threads( body = await request.body() # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "litellm_metadata" not in data: - data["litellm_metadata"] = {} - data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["litellm_metadata"]["headers"] = _headers - data["litellm_metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["litellm_metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["litellm_metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["litellm_metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["litellm_metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -5781,55 +5488,14 @@ async def get_thread( try: # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -5916,55 +5582,14 @@ async def add_messages( data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "litellm_metadata" not in data: - data["litellm_metadata"] = {} - data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key - data["litellm_metadata"]["litellm_api_version"] = version - data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["litellm_metadata"]["headers"] = _headers - data["litellm_metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["litellm_metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["litellm_metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["litellm_metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["litellm_metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -6047,55 +5672,14 @@ async def get_messages( data: Dict = {} try: # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -6180,55 +5764,14 @@ async def run_thread( body = await request.body() data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "litellm_metadata" not in data: - data["litellm_metadata"] = {} - data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key - data["litellm_metadata"]["litellm_api_version"] = version - data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["litellm_metadata"]["headers"] = _headers - data["litellm_metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["litellm_metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["litellm_metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["litellm_metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["litellm_metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -6344,55 +5887,14 @@ async def create_batch( data = {key: value for key, value in form_data.items() if key != "file"} # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call _create_batch_data = CreateBatchRequest(**data) @@ -6489,55 +5991,14 @@ async def retrieve_batch( data = {key: value for key, value in form_data.items() if key != "file"} # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call _retrieve_batch_request = RetrieveBatchRequest( batch_id=batch_id, @@ -6649,55 +6110,14 @@ async def create_file( data = {key: value for key, value in form_data.items() if key != "file"} # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call _create_file_request = CreateFileRequest() @@ -6791,15 +6211,14 @@ async def moderations( data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) data["model"] = ( general_settings.get("moderation_model", None) # server default @@ -6809,46 +6228,6 @@ async def moderations( if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call - router_model_names = llm_router.model_names if llm_router is not None else [] ### CALL HOOKS ### - modify incoming data / reject request before calling the model diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 423f0f2d6d..79824948e3 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Any, Literal, Union +from typing import Optional, List, Any, Literal, Union, TYPE_CHECKING import os import subprocess import hashlib @@ -46,6 +46,15 @@ from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart from datetime import datetime, timedelta from litellm.integrations.slack_alerting import SlackAlerting +from typing_extensions import overload +from functools import wraps + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any def print_verbose(print_statement): @@ -63,6 +72,58 @@ def print_verbose(print_statement): print(f"LiteLLM Proxy: {print_statement}") # noqa +def safe_deep_copy(data): + """ + Safe Deep Copy + + The LiteLLM Request has some object that can-not be pickled / deep copied + + Use this function to safely deep copy the LiteLLM Request + """ + + # Step 1: Remove the litellm_parent_otel_span + if isinstance(data, dict): + # remove litellm_parent_otel_span since this is not picklable + if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]: + litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span") + new_data = copy.deepcopy(data) + + # Step 2: re-add the litellm_parent_otel_span after doing a deep copy + if isinstance(data, dict): + if "metadata" in data: + data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span + return new_data + + +def log_to_opentelemetry(func): + @wraps(func) + async def wrapper(*args, **kwargs): + start_time = datetime.now() + result = await func(*args, **kwargs) + end_time = datetime.now() + + # Log to OTEL only if "parent_otel_span" is in kwargs and is not None + if ( + "parent_otel_span" in kwargs + and kwargs["parent_otel_span"] is not None + and "proxy_logging_obj" in kwargs + and kwargs["proxy_logging_obj"] is not None + ): + proxy_logging_obj = kwargs["proxy_logging_obj"] + await proxy_logging_obj.service_logging_obj.async_service_success_hook( + service=ServiceTypes.DB, + call_type=func.__name__, + parent_otel_span=kwargs["parent_otel_span"], + duration=0.0, + start_time=start_time, + end_time=end_time, + ) + # end of logging to otel + return result + + return wrapper + + ### LOGGING ### class ProxyLogging: """ @@ -282,7 +343,7 @@ class ProxyLogging: """ Runs the CustomLogger's async_moderation_hook() """ - new_data = copy.deepcopy(data) + new_data = safe_deep_copy(data) for callback in litellm.callbacks: try: if isinstance(callback, CustomLogger): @@ -832,6 +893,7 @@ class PrismaClient: max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) + @log_to_opentelemetry async def get_data( self, token: Optional[Union[str, list]] = None, @@ -858,6 +920,8 @@ class PrismaClient: limit: Optional[ int ] = None, # pagination, number of rows to getch when find_all==True + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, ): args_passed_in = locals() start_time = time.time() @@ -2829,6 +2893,10 @@ missing_keys_html_form = """ """ +def _to_ns(dt): + return int(dt.timestamp() * 1e9) + + def get_error_message_str(e: Exception) -> str: error_message = "" if isinstance(e, HTTPException): diff --git a/litellm/py.typed b/litellm/py.typed new file mode 100644 index 0000000000..5686005abc --- /dev/null +++ b/litellm/py.typed @@ -0,0 +1,2 @@ +# Marker file to instruct type checkers to look for inline type annotations in this package. +# See PEP 561 for more information. diff --git a/litellm/tests/log.txt b/litellm/tests/log.txt deleted file mode 100644 index ea07ca7e12..0000000000 --- a/litellm/tests/log.txt +++ /dev/null @@ -1,4274 +0,0 @@ -============================= test session starts ============================== -platform darwin -- Python 3.11.4, pytest-8.2.0, pluggy-1.5.0 -rootdir: /Users/krrishdholakia/Documents/litellm -configfile: pyproject.toml -plugins: asyncio-0.23.6, mock-3.14.0, anyio-4.2.0 -asyncio: mode=Mode.STRICT -collected 1 item - -test_amazing_vertex_completion.py F [100%] - -=================================== FAILURES =================================== -____________________________ test_gemini_pro_vision ____________________________ - -model = 'gemini-1.5-flash-preview-0514' -messages = [{'content': [{'text': 'Whats in this image?', 'type': 'text'}, {'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}, 'type': 'image_url'}], 'role': 'user'}] -model_response = ModelResponse(id='chatcmpl-722df0e7-4e2d-44e6-9e2c-49823faa0189', choices=[Choices(finish_reason='stop', index=0, mess... role='assistant'))], created=1716145725, model=None, object='chat.completion', system_fingerprint=None, usage=Usage()) -print_verbose = -encoding = -logging_obj = -vertex_project = None, vertex_location = None, vertex_credentials = None -optional_params = {} -litellm_params = {'acompletion': False, 'api_base': '', 'api_key': None, 'completion_call_id': None, ...} -logger_fn = None, acompletion = False - - def completion( - model: str, - messages: list, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - vertex_project=None, - vertex_location=None, - vertex_credentials=None, - optional_params=None, - litellm_params=None, - logger_fn=None, - acompletion: bool = False, - ): - try: - import vertexai - except: - raise VertexAIError( - status_code=400, - message="vertexai import failed please run `pip install google-cloud-aiplatform`", - ) - - if not ( - hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models") - ): - raise VertexAIError( - status_code=400, - message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""", - ) - try: - from vertexai.preview.language_models import ( - ChatModel, - CodeChatModel, - InputOutputTextPair, - ) - from vertexai.language_models import TextGenerationModel, CodeGenerationModel - from vertexai.preview.generative_models import ( - GenerativeModel, - Part, - GenerationConfig, - ) - from google.cloud import aiplatform # type: ignore - from google.protobuf import json_format # type: ignore - from google.protobuf.struct_pb2 import Value # type: ignore - from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore - import google.auth # type: ignore - import proto # type: ignore - - ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 - print_verbose( - f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" - ) - if vertex_credentials is not None and isinstance(vertex_credentials, str): - import google.oauth2.service_account - - json_obj = json.loads(vertex_credentials) - - creds = google.oauth2.service_account.Credentials.from_service_account_info( - json_obj, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - else: - creds, _ = google.auth.default(quota_project_id=vertex_project) - print_verbose( - f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" - ) - vertexai.init( - project=vertex_project, location=vertex_location, credentials=creds - ) - - ## Load Config - config = litellm.VertexAIConfig.get_config() - for k, v in config.items(): - if k not in optional_params: - optional_params[k] = v - - ## Process safety settings into format expected by vertex AI - safety_settings = None - if "safety_settings" in optional_params: - safety_settings = optional_params.pop("safety_settings") - if not isinstance(safety_settings, list): - raise ValueError("safety_settings must be a list") - if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict): - raise ValueError("safety_settings must be a list of dicts") - safety_settings = [ - gapic_content_types.SafetySetting(x) for x in safety_settings - ] - - # vertexai does not use an API key, it looks for credentials.json in the environment - - prompt = " ".join( - [ - message["content"] - for message in messages - if isinstance(message["content"], str) - ] - ) - - mode = "" - - request_str = "" - response_obj = None - async_client = None - instances = None - client_options = { - "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com" - } - if ( - model in litellm.vertex_language_models - or model in litellm.vertex_vision_models - ): - llm_model = GenerativeModel(model) - mode = "vision" - request_str += f"llm_model = GenerativeModel({model})\n" - elif model in litellm.vertex_chat_models: - llm_model = ChatModel.from_pretrained(model) - mode = "chat" - request_str += f"llm_model = ChatModel.from_pretrained({model})\n" - elif model in litellm.vertex_text_models: - llm_model = TextGenerationModel.from_pretrained(model) - mode = "text" - request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n" - elif model in litellm.vertex_code_text_models: - llm_model = CodeGenerationModel.from_pretrained(model) - mode = "text" - request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n" - elif model in litellm.vertex_code_chat_models: # vertex_code_llm_models - llm_model = CodeChatModel.from_pretrained(model) - mode = "chat" - request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n" - elif model == "private": - mode = "private" - model = optional_params.pop("model_id", None) - # private endpoint requires a dict instead of JSON - instances = [optional_params.copy()] - instances[0]["prompt"] = prompt - llm_model = aiplatform.PrivateEndpoint( - endpoint_name=model, - project=vertex_project, - location=vertex_location, - ) - request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n" - else: # assume vertex model garden on public endpoint - mode = "custom" - - instances = [optional_params.copy()] - instances[0]["prompt"] = prompt - instances = [ - json_format.ParseDict(instance_dict, Value()) - for instance_dict in instances - ] - # Will determine the API used based on async parameter - llm_model = None - - # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now - if acompletion == True: - data = { - "llm_model": llm_model, - "mode": mode, - "prompt": prompt, - "logging_obj": logging_obj, - "request_str": request_str, - "model": model, - "model_response": model_response, - "encoding": encoding, - "messages": messages, - "print_verbose": print_verbose, - "client_options": client_options, - "instances": instances, - "vertex_location": vertex_location, - "vertex_project": vertex_project, - "safety_settings": safety_settings, - **optional_params, - } - if optional_params.get("stream", False) is True: - # async streaming - return async_streaming(**data) - - return async_completion(**data) - - if mode == "vision": - print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call") - print_verbose(f"\nProcessing input messages = {messages}") - tools = optional_params.pop("tools", None) - content = _gemini_convert_messages_text(messages=messages) - stream = optional_params.pop("stream", False) - if stream == True: - request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - model_response = llm_model.generate_content( - contents={"content": content}, - generation_config=optional_params, - safety_settings=safety_settings, - stream=True, - tools=tools, - ) - - return model_response - - request_str += f"response = llm_model.generate_content({content})\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - ## LLM Call -> response = llm_model.generate_content( - contents=content, - generation_config=optional_params, - safety_settings=safety_settings, - tools=tools, - ) - -../llms/vertex_ai.py:740: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../proxy/myenv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py:405: in generate_content - return self._generate_content( -../proxy/myenv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py:487: in _generate_content - request = self._prepare_request( -../proxy/myenv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py:274: in _prepare_request - contents = [ -../proxy/myenv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py:275: in - gapic_content_types.Content(content_dict) for content_dict in contents -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -self = <[AttributeError('Unknown field for Content: _pb') raised in repr()] Content object at 0x1646aaa90> -mapping = {'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -], 'role': 'user'} -ignore_unknown_fields = False, kwargs = {} -params = {'parts': [text: "Whats in this image?" -, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -], 'role': 'user'} -marshal = , key = 'parts' -value = [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -] -pb_value = [text: "Whats in this image?" -, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -] - - def __init__( - self, - mapping=None, - *, - ignore_unknown_fields=False, - **kwargs, - ): - # We accept several things for `mapping`: - # * An instance of this class. - # * An instance of the underlying protobuf descriptor class. - # * A dict - # * Nothing (keyword arguments only). - if mapping is None: - if not kwargs: - # Special fast path for empty construction. - super().__setattr__("_pb", self._meta.pb()) - return - - mapping = kwargs - elif isinstance(mapping, self._meta.pb): - # Make a copy of the mapping. - # This is a constructor for a new object, so users will assume - # that it will not have side effects on the arguments being - # passed in. - # - # The `wrap` method on the metaclass is the public API for taking - # ownership of the passed in protobuf object. - mapping = copy.deepcopy(mapping) - if kwargs: - mapping.MergeFrom(self._meta.pb(**kwargs)) - - super().__setattr__("_pb", mapping) - return - elif isinstance(mapping, type(self)): - # Just use the above logic on mapping's underlying pb. - self.__init__(mapping=mapping._pb, **kwargs) - return - elif isinstance(mapping, collections.abc.Mapping): - # Can't have side effects on mapping. - mapping = copy.copy(mapping) - # kwargs entries take priority for duplicate keys. - mapping.update(kwargs) - else: - # Sanity check: Did we get something not a map? Error if so. - raise TypeError( - "Invalid constructor input for %s: %r" - % ( - self.__class__.__name__, - mapping, - ) - ) - - params = {} - # Update the mapping to address any values that need to be - # coerced. - marshal = self._meta.marshal - for key, value in mapping.items(): - (key, pb_type) = self._get_pb_type_from_key(key) - if pb_type is None: - if ignore_unknown_fields: - continue - - raise ValueError( - "Unknown field for {}: {}".format(self.__class__.__name__, key) - ) - - try: - pb_value = marshal.to_proto(pb_type, value) - except ValueError: - # Underscores may be appended to field names - # that collide with python or proto-plus keywords. - # In case a key only exists with a `_` suffix, coerce the key - # to include the `_` suffix. It's not possible to - # natively define the same field with a trailing underscore in protobuf. - # See related issue - # https://github.com/googleapis/python-api-core/issues/227 - if isinstance(value, dict): - if _upb: - # In UPB, pb_type is MessageMeta which doesn't expose attrs like it used to in Python/CPP. - keys_to_update = [ - item - for item in value - if item not in pb_type.DESCRIPTOR.fields_by_name - and f"{item}_" in pb_type.DESCRIPTOR.fields_by_name - ] - else: - keys_to_update = [ - item - for item in value - if not hasattr(pb_type, item) - and hasattr(pb_type, f"{item}_") - ] - for item in keys_to_update: - value[f"{item}_"] = value.pop(item) - - pb_value = marshal.to_proto(pb_type, value) - - if pb_value is not None: - params[key] = pb_value - - # Create the internal protocol buffer. -> super().__setattr__("_pb", self._meta.pb(**params)) -E TypeError: Parameter to MergeFrom() must be instance of same class: expected got . - -../proxy/myenv/lib/python3.11/site-packages/proto/message.py:615: TypeError - -During handling of the above exception, another exception occurred: - -model = 'gemini-1.5-flash-preview-0514' -messages = [{'content': [{'text': 'Whats in this image?', 'type': 'text'}, {'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}, 'type': 'image_url'}], 'role': 'user'}] -timeout = 600.0, temperature = None, top_p = None, n = None, stream = None -stream_options = None, stop = None, max_tokens = None, presence_penalty = None -frequency_penalty = None, logit_bias = None, user = None, response_format = None -seed = None, tools = None, tool_choice = None, logprobs = None -top_logprobs = None, deployment_id = None, extra_headers = None -functions = None, function_call = None, base_url = None, api_version = None -api_key = None, model_list = None -kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': } -args = {'acompletion': False, 'api_base': None, 'api_key': None, 'api_version': None, ...} -api_base = None, mock_response = None, force_timeout = 600, logger_fn = None -verbose = False, custom_llm_provider = 'vertex_ai' - - @client - def completion( - model: str, - # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create - messages: List = [], - timeout: Optional[Union[float, str, httpx.Timeout]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - n: Optional[int] = None, - stream: Optional[bool] = None, - stream_options: Optional[dict] = None, - stop=None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[dict] = None, - user: Optional[str] = None, - # openai v1.0+ new params - response_format: Optional[dict] = None, - seed: Optional[int] = None, - tools: Optional[List] = None, - tool_choice: Optional[str] = None, - logprobs: Optional[bool] = None, - top_logprobs: Optional[int] = None, - deployment_id=None, - extra_headers: Optional[dict] = None, - # soon to be deprecated params by OpenAI - functions: Optional[List] = None, - function_call: Optional[str] = None, - # set api_base, api_version, api_key - base_url: Optional[str] = None, - api_version: Optional[str] = None, - api_key: Optional[str] = None, - model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. - # Optional liteLLM function params - **kwargs, - ) -> Union[ModelResponse, CustomStreamWrapper]: - """ - Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) - Parameters: - model (str): The name of the language model to use for text completion. see all supported LLMs: https://docs.litellm.ai/docs/providers/ - messages (List): A list of message objects representing the conversation context (default is an empty list). - - OPTIONAL PARAMS - functions (List, optional): A list of functions to apply to the conversation messages (default is an empty list). - function_call (str, optional): The name of the function to call within the conversation (default is an empty string). - temperature (float, optional): The temperature parameter for controlling the randomness of the output (default is 1.0). - top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). - n (int, optional): The number of completions to generate (default is 1). - stream (bool, optional): If True, return a streaming response (default is False). - stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true. - stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. - max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). - presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. - frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. - logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. - user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. - logprobs (bool, optional): Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message - top_logprobs (int, optional): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. - metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. - api_base (str, optional): Base URL for the API (default is None). - api_version (str, optional): API version (default is None). - api_key (str, optional): API key (default is None). - model_list (list, optional): List of api base, version, keys - extra_headers (dict, optional): Additional headers to include in the request. - - LITELLM Specific Params - mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None). - custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock" - max_retries (int, optional): The number of retries to attempt (default is 0). - Returns: - ModelResponse: A response object containing the generated completion and associated metadata. - - Note: - - This function is used to perform completions() using the specified language model. - - It supports various optional parameters for customizing the completion behavior. - - If 'mock_response' is provided, a mock completion response is returned for testing or debugging. - """ - ######### unpacking kwargs ##################### - args = locals() - api_base = kwargs.get("api_base", None) - mock_response = kwargs.get("mock_response", None) - force_timeout = kwargs.get("force_timeout", 600) ## deprecated - logger_fn = kwargs.get("logger_fn", None) - verbose = kwargs.get("verbose", False) - custom_llm_provider = kwargs.get("custom_llm_provider", None) - litellm_logging_obj = kwargs.get("litellm_logging_obj", None) - id = kwargs.get("id", None) - metadata = kwargs.get("metadata", None) - model_info = kwargs.get("model_info", None) - proxy_server_request = kwargs.get("proxy_server_request", None) - fallbacks = kwargs.get("fallbacks", None) - headers = kwargs.get("headers", None) or extra_headers - num_retries = kwargs.get("num_retries", None) ## deprecated - max_retries = kwargs.get("max_retries", None) - context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None) - organization = kwargs.get("organization", None) - ### CUSTOM MODEL COST ### - input_cost_per_token = kwargs.get("input_cost_per_token", None) - output_cost_per_token = kwargs.get("output_cost_per_token", None) - input_cost_per_second = kwargs.get("input_cost_per_second", None) - output_cost_per_second = kwargs.get("output_cost_per_second", None) - ### CUSTOM PROMPT TEMPLATE ### - initial_prompt_value = kwargs.get("initial_prompt_value", None) - roles = kwargs.get("roles", None) - final_prompt_value = kwargs.get("final_prompt_value", None) - bos_token = kwargs.get("bos_token", None) - eos_token = kwargs.get("eos_token", None) - preset_cache_key = kwargs.get("preset_cache_key", None) - hf_model_name = kwargs.get("hf_model_name", None) - supports_system_message = kwargs.get("supports_system_message", None) - ### TEXT COMPLETION CALLS ### - text_completion = kwargs.get("text_completion", False) - atext_completion = kwargs.get("atext_completion", False) - ### ASYNC CALLS ### - acompletion = kwargs.get("acompletion", False) - client = kwargs.get("client", None) - ### Admin Controls ### - no_log = kwargs.get("no-log", False) - ######## end of unpacking kwargs ########### - openai_params = [ - "functions", - "function_call", - "temperature", - "temperature", - "top_p", - "n", - "stream", - "stream_options", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "request_timeout", - "api_base", - "api_version", - "api_key", - "deployment_id", - "organization", - "base_url", - "default_headers", - "timeout", - "response_format", - "seed", - "tools", - "tool_choice", - "max_retries", - "logprobs", - "top_logprobs", - "extra_headers", - ] - litellm_params = [ - "metadata", - "acompletion", - "atext_completion", - "text_completion", - "caching", - "mock_response", - "api_key", - "api_version", - "api_base", - "force_timeout", - "logger_fn", - "verbose", - "custom_llm_provider", - "litellm_logging_obj", - "litellm_call_id", - "use_client", - "id", - "fallbacks", - "azure", - "headers", - "model_list", - "num_retries", - "context_window_fallback_dict", - "retry_policy", - "roles", - "final_prompt_value", - "bos_token", - "eos_token", - "request_timeout", - "complete_response", - "self", - "client", - "rpm", - "tpm", - "max_parallel_requests", - "input_cost_per_token", - "output_cost_per_token", - "input_cost_per_second", - "output_cost_per_second", - "hf_model_name", - "model_info", - "proxy_server_request", - "preset_cache_key", - "caching_groups", - "ttl", - "cache", - "no-log", - "base_model", - "stream_timeout", - "supports_system_message", - "region_name", - "allowed_model_region", - "model_config", - ] - - default_params = openai_params + litellm_params - non_default_params = { - k: v for k, v in kwargs.items() if k not in default_params - } # model-specific params - pass them straight to the model/provider - - try: - if base_url is not None: - api_base = base_url - if max_retries is not None: # openai allows openai.OpenAI(max_retries=3) - num_retries = max_retries - logging = litellm_logging_obj - fallbacks = fallbacks or litellm.model_fallbacks - if fallbacks is not None: - return completion_with_fallbacks(**args) - if model_list is not None: - deployments = [ - m["litellm_params"] for m in model_list if m["model_name"] == model - ] - return batch_completion_models(deployments=deployments, **args) - if litellm.model_alias_map and model in litellm.model_alias_map: - model = litellm.model_alias_map[ - model - ] # update the model to the actual value if an alias has been passed in - model_response = ModelResponse() - setattr(model_response, "usage", litellm.Usage()) - if ( - kwargs.get("azure", False) == True - ): # don't remove flag check, to remain backwards compatible for repos like Codium - custom_llm_provider = "azure" - if deployment_id != None: # azure llms - model = deployment_id - custom_llm_provider = "azure" - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - api_key=api_key, - ) - if model_response is not None and hasattr(model_response, "_hidden_params"): - model_response._hidden_params["custom_llm_provider"] = custom_llm_provider - model_response._hidden_params["region_name"] = kwargs.get( - "aws_region_name", None - ) # support region-based pricing for bedrock - - ### TIMEOUT LOGIC ### - timeout = timeout or kwargs.get("request_timeout", 600) or 600 - # set timeout for 10 minutes by default - if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout( - custom_llm_provider - ): - timeout = timeout.read or 600 # default 10 min timeout - elif not isinstance(timeout, httpx.Timeout): - timeout = float(timeout) # type: ignore - - ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### - if input_cost_per_token is not None and output_cost_per_token is not None: - litellm.register_model( - { - f"{custom_llm_provider}/{model}": { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider, - }, - model: { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider, - }, - } - ) - elif ( - input_cost_per_second is not None - ): # time based pricing just needs cost in place - output_cost_per_second = output_cost_per_second - litellm.register_model( - { - f"{custom_llm_provider}/{model}": { - "input_cost_per_second": input_cost_per_second, - "output_cost_per_second": output_cost_per_second, - "litellm_provider": custom_llm_provider, - }, - model: { - "input_cost_per_second": input_cost_per_second, - "output_cost_per_second": output_cost_per_second, - "litellm_provider": custom_llm_provider, - }, - } - ) - ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### - custom_prompt_dict = {} # type: ignore - if ( - initial_prompt_value - or roles - or final_prompt_value - or bos_token - or eos_token - ): - custom_prompt_dict = {model: {}} - if initial_prompt_value: - custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value - if roles: - custom_prompt_dict[model]["roles"] = roles - if final_prompt_value: - custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value - if bos_token: - custom_prompt_dict[model]["bos_token"] = bos_token - if eos_token: - custom_prompt_dict[model]["eos_token"] = eos_token - - if ( - supports_system_message is not None - and isinstance(supports_system_message, bool) - and supports_system_message == False - ): - messages = map_system_message_pt(messages=messages) - model_api_key = get_api_key( - llm_provider=custom_llm_provider, dynamic_api_key=api_key - ) # get the api key from the environment if required for the model - - if dynamic_api_key is not None: - api_key = dynamic_api_key - # check if user passed in any of the OpenAI optional params - optional_params = get_optional_params( - functions=functions, - function_call=function_call, - temperature=temperature, - top_p=top_p, - n=n, - stream=stream, - stream_options=stream_options, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - user=user, - # params to identify the model - model=model, - custom_llm_provider=custom_llm_provider, - response_format=response_format, - seed=seed, - tools=tools, - tool_choice=tool_choice, - max_retries=max_retries, - logprobs=logprobs, - top_logprobs=top_logprobs, - extra_headers=extra_headers, - **non_default_params, - ) - - if litellm.add_function_to_prompt and optional_params.get( - "functions_unsupported_model", None - ): # if user opts to add it to prompt, when API doesn't support function calling - functions_unsupported_model = optional_params.pop( - "functions_unsupported_model" - ) - messages = function_call_prompt( - messages=messages, functions=functions_unsupported_model - ) - - # For logging - save the values of the litellm-specific params passed in - litellm_params = get_litellm_params( - acompletion=acompletion, - api_key=api_key, - force_timeout=force_timeout, - logger_fn=logger_fn, - verbose=verbose, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - litellm_call_id=kwargs.get("litellm_call_id", None), - model_alias_map=litellm.model_alias_map, - completion_call_id=id, - metadata=metadata, - model_info=model_info, - proxy_server_request=proxy_server_request, - preset_cache_key=preset_cache_key, - no_log=no_log, - input_cost_per_second=input_cost_per_second, - input_cost_per_token=input_cost_per_token, - output_cost_per_second=output_cost_per_second, - output_cost_per_token=output_cost_per_token, - ) - logging.update_environment_variables( - model=model, - user=user, - optional_params=optional_params, - litellm_params=litellm_params, - ) - if mock_response: - return mock_completion( - model, - messages, - stream=stream, - mock_response=mock_response, - logging=logging, - acompletion=acompletion, - ) - if custom_llm_provider == "azure": - # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" - - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") - - api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") - ) - - azure_ad_token = optional_params.get("extra_body", {}).pop( - "azure_ad_token", None - ) or get_secret("AZURE_AD_TOKEN") - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.AzureOpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - response = azure_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - api_version=api_version, - api_type=api_type, - azure_ad_token=azure_ad_token, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, # type: ignore - client=client, # pass AsyncAzureOpenAI, AzureOpenAI client - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={ - "headers": headers, - "api_version": api_version, - "api_base": api_base, - }, - ) - elif custom_llm_provider == "azure_text": - # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" - - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") - - api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") - ) - - azure_ad_token = optional_params.get("extra_body", {}).pop( - "azure_ad_token", None - ) or get_secret("AZURE_AD_TOKEN") - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.AzureOpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - response = azure_text_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - api_version=api_version, - api_type=api_type, - azure_ad_token=azure_ad_token, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, - client=client, # pass AsyncAzureOpenAI, AzureOpenAI client - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={ - "headers": headers, - "api_version": api_version, - "api_base": api_base, - }, - ) - elif ( - model in litellm.open_ai_chat_completion_models - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "deepinfra" - or custom_llm_provider == "perplexity" - or custom_llm_provider == "groq" - or custom_llm_provider == "deepseek" - or custom_llm_provider == "anyscale" - or custom_llm_provider == "mistral" - or custom_llm_provider == "openai" - or custom_llm_provider == "together_ai" - or custom_llm_provider in litellm.openai_compatible_providers - or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo - ): # allow user to make an openai call with a custom base - # note: if a user sets a custom base - we should ensure this works - # allow for the setting of dynamic and stateful api-bases - api_base = ( - api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there - or litellm.api_base - or get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) - openai.organization = ( - organization - or litellm.organization - or get_secret("OPENAI_ORGANIZATION") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 - ) - # set API KEY - api_key = ( - api_key - or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.OpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - try: - response = openai_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - model_response=model_response, - print_verbose=print_verbose, - api_key=api_key, - api_base=api_base, - acompletion=acompletion, - logging_obj=logging, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - timeout=timeout, # type: ignore - custom_prompt_dict=custom_prompt_dict, - client=client, # pass AsyncOpenAI, OpenAI client - organization=organization, - custom_llm_provider=custom_llm_provider, - ) - except Exception as e: - ## LOGGING - log the original exception returned - logging.post_call( - input=messages, - api_key=api_key, - original_response=str(e), - additional_args={"headers": headers}, - ) - raise e - - if optional_params.get("stream", False): - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={"headers": headers}, - ) - elif ( - custom_llm_provider == "text-completion-openai" - or "ft:babbage-002" in model - or "ft:davinci-002" in model # support for finetuned completion models - ): - openai.api_type = "openai" - - api_base = ( - api_base - or litellm.api_base - or get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) - - openai.api_version = None - # set API KEY - - api_key = ( - api_key - or litellm.api_key - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.OpenAITextCompletionConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - if litellm.organization: - openai.organization = litellm.organization - - if ( - len(messages) > 0 - and "content" in messages[0] - and type(messages[0]["content"]) == list - ): - # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content'] - # https://platform.openai.com/docs/api-reference/completions/create - prompt = messages[0]["content"] - else: - prompt = " ".join([message["content"] for message in messages]) # type: ignore - - ## COMPLETION CALL - _response = openai_text_completions.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - api_key=api_key, - api_base=api_base, - acompletion=acompletion, - client=client, # pass AsyncOpenAI, OpenAI client - logging_obj=logging, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - timeout=timeout, # type: ignore - ) - - if ( - optional_params.get("stream", False) == False - and acompletion == False - and text_completion == False - ): - # convert to chat completion response - _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object( - response_object=_response, model_response_object=model_response - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=_response, - additional_args={"headers": headers}, - ) - response = _response - elif ( - "replicate" in model - or custom_llm_provider == "replicate" - or model in litellm.replicate_models - ): - # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN") - replicate_key = None - replicate_key = ( - api_key - or litellm.replicate_key - or litellm.api_key - or get_secret("REPLICATE_API_KEY") - or get_secret("REPLICATE_API_TOKEN") - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("REPLICATE_API_BASE") - or "https://api.replicate.com/v1" - ) - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - model_response = replicate.completion( # type: ignore - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=replicate_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - acompletion=acompletion, - ) - - if optional_params.get("stream", False) == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=replicate_key, - original_response=model_response, - ) - - response = model_response - elif ( - "clarifai" in model - or custom_llm_provider == "clarifai" - or model in litellm.clarifai_models - ): - clarifai_key = None - clarifai_key = ( - api_key - or litellm.clarifai_key - or litellm.api_key - or get_secret("CLARIFAI_API_KEY") - or get_secret("CLARIFAI_API_TOKEN") - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("CLARIFAI_API_BASE") - or "https://api.clarifai.com/v2" - ) - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - model_response = clarifai.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - acompletion=acompletion, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=clarifai_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=model_response, - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=clarifai_key, - original_response=model_response, - ) - response = model_response - - elif custom_llm_provider == "anthropic": - api_key = ( - api_key - or litellm.anthropic_key - or litellm.api_key - or os.environ.get("ANTHROPIC_API_KEY") - ) - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - if (model == "claude-2") or (model == "claude-instant-1"): - # call anthropic /completion, only use this route for claude-2, claude-instant-1 - api_base = ( - api_base - or litellm.api_base - or get_secret("ANTHROPIC_API_BASE") - or "https://api.anthropic.com/v1/complete" - ) - response = anthropic_text_completions.completion( - model=model, - messages=messages, - api_base=api_base, - acompletion=acompletion, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=api_key, - logging_obj=logging, - headers=headers, - ) - else: - # call /messages - # default route for all anthropic models - api_base = ( - api_base - or litellm.api_base - or get_secret("ANTHROPIC_API_BASE") - or "https://api.anthropic.com/v1/messages" - ) - response = anthropic_chat_completions.completion( - model=model, - messages=messages, - api_base=api_base, - acompletion=acompletion, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=api_key, - logging_obj=logging, - headers=headers, - ) - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - ) - response = response - elif custom_llm_provider == "nlp_cloud": - nlp_cloud_key = ( - api_key - or litellm.nlp_cloud_key - or get_secret("NLP_CLOUD_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("NLP_CLOUD_API_BASE") - or "https://api.nlpcloud.io/v1/gpu/" - ) - - response = nlp_cloud.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=nlp_cloud_key, - logging_obj=logging, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="nlp_cloud", - logging_obj=logging, - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - ) - - response = response - elif custom_llm_provider == "aleph_alpha": - aleph_alpha_key = ( - api_key - or litellm.aleph_alpha_key - or get_secret("ALEPH_ALPHA_API_KEY") - or get_secret("ALEPHALPHA_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("ALEPH_ALPHA_API_BASE") - or "https://api.aleph-alpha.com/complete" - ) - - model_response = aleph_alpha.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - default_max_tokens_to_sample=litellm.max_tokens, - api_key=aleph_alpha_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="aleph_alpha", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "cohere": - cohere_key = ( - api_key - or litellm.cohere_key - or get_secret("COHERE_API_KEY") - or get_secret("CO_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("COHERE_API_BASE") - or "https://api.cohere.ai/v1/generate" - ) - - model_response = cohere.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=cohere_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="cohere", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "cohere_chat": - cohere_key = ( - api_key - or litellm.cohere_key - or get_secret("COHERE_API_KEY") - or get_secret("CO_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("COHERE_API_BASE") - or "https://api.cohere.ai/v1/chat" - ) - - model_response = cohere_chat.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=cohere_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="cohere_chat", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "maritalk": - maritalk_key = ( - api_key - or litellm.maritalk_key - or get_secret("MARITALK_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("MARITALK_API_BASE") - or "https://chat.maritaca.ai/api/chat/inference" - ) - - model_response = maritalk.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=maritalk_key, - logging_obj=logging, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="maritalk", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "huggingface": - custom_llm_provider = "huggingface" - huggingface_key = ( - api_key - or litellm.huggingface_key - or os.environ.get("HF_TOKEN") - or os.environ.get("HUGGINGFACE_API_KEY") - or litellm.api_key - ) - hf_headers = headers or litellm.headers - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - model_response = huggingface.completion( - model=model, - messages=messages, - api_base=api_base, # type: ignore - headers=hf_headers, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=huggingface_key, - acompletion=acompletion, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - timeout=timeout, # type: ignore - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion is False - ): - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="huggingface", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "oobabooga": - custom_llm_provider = "oobabooga" - model_response = oobabooga.completion( - model=model, - messages=messages, - model_response=model_response, - api_base=api_base, # type: ignore - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - api_key=None, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - ) - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="oobabooga", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "openrouter": - api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1" - - api_key = ( - api_key - or litellm.api_key - or litellm.openrouter_key - or get_secret("OPENROUTER_API_KEY") - or get_secret("OR_API_KEY") - ) - - openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai" - - openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" - - headers = ( - headers - or litellm.headers - or { - "HTTP-Referer": openrouter_site_url, - "X-Title": openrouter_app_name, - } - ) - - ## Load Config - config = openrouter.OpenrouterConfig.get_config() - for k, v in config.items(): - if k == "extra_body": - # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models - if "extra_body" in optional_params: - optional_params[k].update(v) - else: - optional_params[k] = v - elif k not in optional_params: - optional_params[k] = v - - data = {"model": model, "messages": messages, **optional_params} - - ## COMPLETION CALL - response = openai_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, # type: ignore - ) - ## LOGGING - logging.post_call( - input=messages, api_key=openai.api_key, original_response=response - ) - elif ( - custom_llm_provider == "together_ai" - or ("togethercomputer" in model) - or (model in litellm.together_ai_models) - ): - """ - Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility - """ - custom_llm_provider = "together_ai" - together_ai_key = ( - api_key - or litellm.togetherai_api_key - or get_secret("TOGETHER_AI_TOKEN") - or get_secret("TOGETHERAI_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("TOGETHERAI_API_BASE") - or "https://api.together.xyz/inference" - ) - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - model_response = together_ai.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=together_ai_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - ) - if ( - "stream_tokens" in optional_params - and optional_params["stream_tokens"] == True - ): - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="together_ai", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "palm": - palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key - - # palm does not support streaming as yet :( - model_response = palm.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=palm_api_key, - logging_obj=logging, - ) - # fake palm streaming - if "stream" in optional_params and optional_params["stream"] == True: - # fake streaming for palm - resp_string = model_response["choices"][0]["message"]["content"] - response = CustomStreamWrapper( - resp_string, model, custom_llm_provider="palm", logging_obj=logging - ) - return response - response = model_response - elif custom_llm_provider == "gemini": - gemini_api_key = ( - api_key - or get_secret("GEMINI_API_KEY") - or get_secret("PALM_API_KEY") # older palm api key should also work - or litellm.api_key - ) - - # palm does not support streaming as yet :( - model_response = gemini.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=gemini_api_key, - logging_obj=logging, - acompletion=acompletion, - custom_prompt_dict=custom_prompt_dict, - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion == False - ): - response = CustomStreamWrapper( - iter(model_response), - model, - custom_llm_provider="gemini", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "vertex_ai": - vertex_ai_project = ( - optional_params.pop("vertex_project", None) - or optional_params.pop("vertex_ai_project", None) - or litellm.vertex_project - or get_secret("VERTEXAI_PROJECT") - ) - vertex_ai_location = ( - optional_params.pop("vertex_location", None) - or optional_params.pop("vertex_ai_location", None) - or litellm.vertex_location - or get_secret("VERTEXAI_LOCATION") - ) - vertex_credentials = ( - optional_params.pop("vertex_credentials", None) - or optional_params.pop("vertex_ai_credentials", None) - or get_secret("VERTEXAI_CREDENTIALS") - ) - new_params = deepcopy(optional_params) - if "claude-3" in model: - model_response = vertex_ai_anthropic.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - ) - else: -> model_response = vertex_ai.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - ) - -../main.py:1824: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -model = 'gemini-1.5-flash-preview-0514' -messages = [{'content': [{'text': 'Whats in this image?', 'type': 'text'}, {'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}, 'type': 'image_url'}], 'role': 'user'}] -model_response = ModelResponse(id='chatcmpl-722df0e7-4e2d-44e6-9e2c-49823faa0189', choices=[Choices(finish_reason='stop', index=0, mess... role='assistant'))], created=1716145725, model=None, object='chat.completion', system_fingerprint=None, usage=Usage()) -print_verbose = -encoding = -logging_obj = -vertex_project = None, vertex_location = None, vertex_credentials = None -optional_params = {} -litellm_params = {'acompletion': False, 'api_base': '', 'api_key': None, 'completion_call_id': None, ...} -logger_fn = None, acompletion = False - - def completion( - model: str, - messages: list, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - vertex_project=None, - vertex_location=None, - vertex_credentials=None, - optional_params=None, - litellm_params=None, - logger_fn=None, - acompletion: bool = False, - ): - try: - import vertexai - except: - raise VertexAIError( - status_code=400, - message="vertexai import failed please run `pip install google-cloud-aiplatform`", - ) - - if not ( - hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models") - ): - raise VertexAIError( - status_code=400, - message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""", - ) - try: - from vertexai.preview.language_models import ( - ChatModel, - CodeChatModel, - InputOutputTextPair, - ) - from vertexai.language_models import TextGenerationModel, CodeGenerationModel - from vertexai.preview.generative_models import ( - GenerativeModel, - Part, - GenerationConfig, - ) - from google.cloud import aiplatform # type: ignore - from google.protobuf import json_format # type: ignore - from google.protobuf.struct_pb2 import Value # type: ignore - from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore - import google.auth # type: ignore - import proto # type: ignore - - ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 - print_verbose( - f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" - ) - if vertex_credentials is not None and isinstance(vertex_credentials, str): - import google.oauth2.service_account - - json_obj = json.loads(vertex_credentials) - - creds = google.oauth2.service_account.Credentials.from_service_account_info( - json_obj, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - else: - creds, _ = google.auth.default(quota_project_id=vertex_project) - print_verbose( - f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" - ) - vertexai.init( - project=vertex_project, location=vertex_location, credentials=creds - ) - - ## Load Config - config = litellm.VertexAIConfig.get_config() - for k, v in config.items(): - if k not in optional_params: - optional_params[k] = v - - ## Process safety settings into format expected by vertex AI - safety_settings = None - if "safety_settings" in optional_params: - safety_settings = optional_params.pop("safety_settings") - if not isinstance(safety_settings, list): - raise ValueError("safety_settings must be a list") - if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict): - raise ValueError("safety_settings must be a list of dicts") - safety_settings = [ - gapic_content_types.SafetySetting(x) for x in safety_settings - ] - - # vertexai does not use an API key, it looks for credentials.json in the environment - - prompt = " ".join( - [ - message["content"] - for message in messages - if isinstance(message["content"], str) - ] - ) - - mode = "" - - request_str = "" - response_obj = None - async_client = None - instances = None - client_options = { - "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com" - } - if ( - model in litellm.vertex_language_models - or model in litellm.vertex_vision_models - ): - llm_model = GenerativeModel(model) - mode = "vision" - request_str += f"llm_model = GenerativeModel({model})\n" - elif model in litellm.vertex_chat_models: - llm_model = ChatModel.from_pretrained(model) - mode = "chat" - request_str += f"llm_model = ChatModel.from_pretrained({model})\n" - elif model in litellm.vertex_text_models: - llm_model = TextGenerationModel.from_pretrained(model) - mode = "text" - request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n" - elif model in litellm.vertex_code_text_models: - llm_model = CodeGenerationModel.from_pretrained(model) - mode = "text" - request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n" - elif model in litellm.vertex_code_chat_models: # vertex_code_llm_models - llm_model = CodeChatModel.from_pretrained(model) - mode = "chat" - request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n" - elif model == "private": - mode = "private" - model = optional_params.pop("model_id", None) - # private endpoint requires a dict instead of JSON - instances = [optional_params.copy()] - instances[0]["prompt"] = prompt - llm_model = aiplatform.PrivateEndpoint( - endpoint_name=model, - project=vertex_project, - location=vertex_location, - ) - request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n" - else: # assume vertex model garden on public endpoint - mode = "custom" - - instances = [optional_params.copy()] - instances[0]["prompt"] = prompt - instances = [ - json_format.ParseDict(instance_dict, Value()) - for instance_dict in instances - ] - # Will determine the API used based on async parameter - llm_model = None - - # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now - if acompletion == True: - data = { - "llm_model": llm_model, - "mode": mode, - "prompt": prompt, - "logging_obj": logging_obj, - "request_str": request_str, - "model": model, - "model_response": model_response, - "encoding": encoding, - "messages": messages, - "print_verbose": print_verbose, - "client_options": client_options, - "instances": instances, - "vertex_location": vertex_location, - "vertex_project": vertex_project, - "safety_settings": safety_settings, - **optional_params, - } - if optional_params.get("stream", False) is True: - # async streaming - return async_streaming(**data) - - return async_completion(**data) - - if mode == "vision": - print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call") - print_verbose(f"\nProcessing input messages = {messages}") - tools = optional_params.pop("tools", None) - content = _gemini_convert_messages_text(messages=messages) - stream = optional_params.pop("stream", False) - if stream == True: - request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - model_response = llm_model.generate_content( - contents={"content": content}, - generation_config=optional_params, - safety_settings=safety_settings, - stream=True, - tools=tools, - ) - - return model_response - - request_str += f"response = llm_model.generate_content({content})\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - ## LLM Call - response = llm_model.generate_content( - contents=content, - generation_config=optional_params, - safety_settings=safety_settings, - tools=tools, - ) - - if tools is not None and bool( - getattr(response.candidates[0].content.parts[0], "function_call", None) - ): - function_call = response.candidates[0].content.parts[0].function_call - args_dict = {} - - # Check if it's a RepeatedComposite instance - for key, val in function_call.args.items(): - if isinstance( - val, proto.marshal.collections.repeated.RepeatedComposite - ): - # If so, convert to list - args_dict[key] = [v for v in val] - else: - args_dict[key] = val - - try: - args_str = json.dumps(args_dict) - except Exception as e: - raise VertexAIError(status_code=422, message=str(e)) - message = litellm.Message( - content=None, - tool_calls=[ - { - "id": f"call_{str(uuid.uuid4())}", - "function": { - "arguments": args_str, - "name": function_call.name, - }, - "type": "function", - } - ], - ) - completion_response = message - else: - completion_response = response.text - response_obj = response._raw_response - optional_params["tools"] = tools - elif mode == "chat": - chat = llm_model.start_chat() - request_str += f"chat = llm_model.start_chat()\n" - - if "stream" in optional_params and optional_params["stream"] == True: - # NOTE: VertexAI does not accept stream=True as a param and raises an error, - # we handle this by removing 'stream' from optional params and sending the request - # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format - optional_params.pop( - "stream", None - ) # vertex ai raises an error when passing stream in optional params - request_str += ( - f"chat.send_message_streaming({prompt}, **{optional_params})\n" - ) - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - model_response = chat.send_message_streaming(prompt, **optional_params) - - return model_response - - request_str += f"chat.send_message({prompt}, **{optional_params}).text\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - completion_response = chat.send_message(prompt, **optional_params).text - elif mode == "text": - if "stream" in optional_params and optional_params["stream"] == True: - optional_params.pop( - "stream", None - ) # See note above on handling streaming for vertex ai - request_str += ( - f"llm_model.predict_streaming({prompt}, **{optional_params})\n" - ) - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - model_response = llm_model.predict_streaming(prompt, **optional_params) - - return model_response - - request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - completion_response = llm_model.predict(prompt, **optional_params).text - elif mode == "custom": - """ - Vertex AI Model Garden - """ - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - llm_model = aiplatform.gapic.PredictionServiceClient( - client_options=client_options - ) - request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n" - endpoint_path = llm_model.endpoint_path( - project=vertex_project, location=vertex_location, endpoint=model - ) - request_str += ( - f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n" - ) - response = llm_model.predict( - endpoint=endpoint_path, instances=instances - ).predictions - - completion_response = response[0] - if ( - isinstance(completion_response, str) - and "\nOutput:\n" in completion_response - ): - completion_response = completion_response.split("\nOutput:\n", 1)[1] - if "stream" in optional_params and optional_params["stream"] == True: - response = TextStreamer(completion_response) - return response - elif mode == "private": - """ - Vertex AI Model Garden deployed on private endpoint - """ - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - request_str += f"llm_model.predict(instances={instances})\n" - response = llm_model.predict(instances=instances).predictions - - completion_response = response[0] - if ( - isinstance(completion_response, str) - and "\nOutput:\n" in completion_response - ): - completion_response = completion_response.split("\nOutput:\n", 1)[1] - if "stream" in optional_params and optional_params["stream"] == True: - response = TextStreamer(completion_response) - return response - - ## LOGGING - logging_obj.post_call( - input=prompt, api_key=None, original_response=completion_response - ) - - ## RESPONSE OBJECT - if isinstance(completion_response, litellm.Message): - model_response["choices"][0]["message"] = completion_response - elif len(str(completion_response)) > 0: - model_response["choices"][0]["message"]["content"] = str( - completion_response - ) - model_response["created"] = int(time.time()) - model_response["model"] = model - ## CALCULATING USAGE - if model in litellm.vertex_language_models and response_obj is not None: - model_response["choices"][0].finish_reason = map_finish_reason( - response_obj.candidates[0].finish_reason.name - ) - usage = Usage( - prompt_tokens=response_obj.usage_metadata.prompt_token_count, - completion_tokens=response_obj.usage_metadata.candidates_token_count, - total_tokens=response_obj.usage_metadata.total_token_count, - ) - else: - # init prompt tokens - # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter - prompt_tokens, completion_tokens, total_tokens = 0, 0, 0 - if response_obj is not None: - if hasattr(response_obj, "usage_metadata") and hasattr( - response_obj.usage_metadata, "prompt_token_count" - ): - prompt_tokens = response_obj.usage_metadata.prompt_token_count - completion_tokens = ( - response_obj.usage_metadata.candidates_token_count - ) - else: - prompt_tokens = len(encoding.encode(prompt)) - completion_tokens = len( - encoding.encode( - model_response["choices"][0]["message"].get("content", "") - ) - ) - - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - setattr(model_response, "usage", usage) - return model_response - except Exception as e: - if isinstance(e, VertexAIError): - raise e -> raise VertexAIError(status_code=500, message=str(e)) -E litellm.llms.vertex_ai.VertexAIError: Parameter to MergeFrom() must be instance of same class: expected got . - -../llms/vertex_ai.py:971: VertexAIError - -During handling of the above exception, another exception occurred: - -args = () -kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': -call_type = 'completion', model = 'vertex_ai/gemini-1.5-flash-preview-0514' -k = 'litellm_logging_obj' - - @wraps(original_function) - def wrapper(*args, **kwargs): - # DO NOT MOVE THIS. It always needs to run first - # Check if this is an async function. If so only execute the async function - if ( - kwargs.get("acompletion", False) == True - or kwargs.get("aembedding", False) == True - or kwargs.get("aimg_generation", False) == True - or kwargs.get("amoderation", False) == True - or kwargs.get("atext_completion", False) == True - or kwargs.get("atranscription", False) == True - ): - # [OPTIONAL] CHECK MAX RETRIES / REQUEST - if litellm.num_retries_per_request is not None: - # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] - previous_models = kwargs.get("metadata", {}).get( - "previous_models", None - ) - if previous_models is not None: - if litellm.num_retries_per_request <= len(previous_models): - raise Exception(f"Max retries per request hit!") - - # MODEL CALL - result = original_function(*args, **kwargs) - if "stream" in kwargs and kwargs["stream"] == True: - if ( - "complete_response" in kwargs - and kwargs["complete_response"] == True - ): - chunks = [] - for idx, chunk in enumerate(result): - chunks.append(chunk) - return litellm.stream_chunk_builder( - chunks, messages=kwargs.get("messages", None) - ) - else: - return result - - return result - - # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print - print_args_passed_to_litellm(original_function, args, kwargs) - start_time = datetime.datetime.now() - result = None - logging_obj = kwargs.get("litellm_logging_obj", None) - - # only set litellm_call_id if its not in kwargs - call_type = original_function.__name__ - if "litellm_call_id" not in kwargs: - kwargs["litellm_call_id"] = str(uuid.uuid4()) - try: - model = args[0] if len(args) > 0 else kwargs["model"] - except: - model = None - if ( - call_type != CallTypes.image_generation.value - and call_type != CallTypes.text_completion.value - ): - raise ValueError("model param not passed in.") - - try: - if logging_obj is None: - logging_obj, kwargs = function_setup( - original_function.__name__, rules_obj, start_time, *args, **kwargs - ) - kwargs["litellm_logging_obj"] = logging_obj - - # CHECK FOR 'os.environ/' in kwargs - for k, v in kwargs.items(): - if v is not None and isinstance(v, str) and v.startswith("os.environ/"): - kwargs[k] = litellm.get_secret(v) - # [OPTIONAL] CHECK BUDGET - if litellm.max_budget: - if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError( - current_cost=litellm._current_cost, - max_budget=litellm.max_budget, - ) - - # [OPTIONAL] CHECK MAX RETRIES / REQUEST - if litellm.num_retries_per_request is not None: - # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] - previous_models = kwargs.get("metadata", {}).get( - "previous_models", None - ) - if previous_models is not None: - if litellm.num_retries_per_request <= len(previous_models): - raise Exception(f"Max retries per request hit!") - - # [OPTIONAL] CHECK CACHE - print_verbose( - f"SYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache')['no-cache']: {kwargs.get('cache', {}).get('no-cache', False)}" - ) - # if caching is false or cache["no-cache"]==True, don't run this - if ( - ( - ( - ( - kwargs.get("caching", None) is None - and litellm.cache is not None - ) - or kwargs.get("caching", False) == True - ) - and kwargs.get("cache", {}).get("no-cache", False) != True - ) - and kwargs.get("aembedding", False) != True - and kwargs.get("atext_completion", False) != True - and kwargs.get("acompletion", False) != True - and kwargs.get("aimg_generation", False) != True - and kwargs.get("atranscription", False) != True - ): # allow users to control returning cached responses from the completion function - # checking cache - print_verbose(f"INSIDE CHECKING CACHE") - if ( - litellm.cache is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): - print_verbose(f"Checking Cache") - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = litellm.cache.get_cache(*args, **kwargs) - if cached_result != None: - if "detail" in cached_result: - # implies an error occurred - pass - else: - call_type = original_function.__name__ - print_verbose( - f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" - ) - if call_type == CallTypes.completion.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=ModelResponse(), - stream=kwargs.get("stream", False), - ) - - if kwargs.get("stream", False) == True: - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - elif call_type == CallTypes.embedding.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - response_type="embedding", - ) - - # LOG SUCCESS - cache_hit = True - end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( - model=model, - custom_llm_provider=kwargs.get( - "custom_llm_provider", None - ), - api_base=kwargs.get("api_base", None), - api_key=kwargs.get("api_key", None), - ) - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - logging_obj.update_environment_variables( - model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": False, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get( - "preset_cache_key", None - ), - "stream_response": kwargs.get( - "stream_response", {} - ), - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(cached_result), - additional_args=None, - stream=kwargs.get("stream", False), - ) - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() - return cached_result - - # CHECK MAX TOKENS - if ( - kwargs.get("max_tokens", None) is not None - and model is not None - and litellm.modify_params - == True # user is okay with params being modified - and ( - call_type == CallTypes.acompletion.value - or call_type == CallTypes.completion.value - ) - ): - try: - base_model = model - if kwargs.get("hf_model_name", None) is not None: - base_model = f"huggingface/{kwargs.get('hf_model_name')}" - max_output_tokens = ( - get_max_tokens(model=base_model) or 4096 - ) # assume min context window is 4k tokens - user_max_tokens = kwargs.get("max_tokens") - ## Scenario 1: User limit + prompt > model limit - messages = None - if len(args) > 1: - messages = args[1] - elif kwargs.get("messages", None): - messages = kwargs["messages"] - input_tokens = token_counter(model=base_model, messages=messages) - input_tokens += max( - 0.1 * input_tokens, 10 - ) # give at least a 10 token buffer. token counting can be imprecise. - if input_tokens > max_output_tokens: - pass # allow call to fail normally - elif user_max_tokens + input_tokens > max_output_tokens: - user_max_tokens = max_output_tokens - input_tokens - print_verbose(f"user_max_tokens: {user_max_tokens}") - kwargs["max_tokens"] = int( - round(user_max_tokens) - ) # make sure max tokens is always an int - except Exception as e: - print_verbose(f"Error while checking max token limit: {str(e)}") - # MODEL CALL -> result = original_function(*args, **kwargs) - -../utils.py:3211: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../main.py:2368: in completion - raise exception_type( -../utils.py:9709: in exception_type - raise e -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -model = 'gemini-1.5-flash-preview-0514' -original_exception = VertexAIError("Parameter to MergeFrom() must be instance of same class: expected got .") -custom_llm_provider = 'vertex_ai' -completion_kwargs = {'acompletion': False, 'api_base': None, 'api_key': None, 'api_version': None, ...} -extra_kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': } - - def exception_type( - model, - original_exception, - custom_llm_provider, - completion_kwargs={}, - extra_kwargs={}, - ): - global user_logger_fn, liteDebuggerClient - exception_mapping_worked = False - if litellm.suppress_debug_info is False: - print() # noqa - print( # noqa - "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa - ) # noqa - print( # noqa - "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." # noqa - ) # noqa - print() # noqa - try: - if model: - error_str = str(original_exception) - if isinstance(original_exception, BaseException): - exception_type = type(original_exception).__name__ - else: - exception_type = "" - - ################################################################################ - # Common Extra information needed for all providers - # We pass num retries, api_base, vertex_deployment etc to the exception here - ################################################################################ - extra_information = "" - try: - _api_base = litellm.get_api_base( - model=model, optional_params=extra_kwargs - ) - messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) - _vertex_project = extra_kwargs.get("vertex_project") - _vertex_location = extra_kwargs.get("vertex_location") - _metadata = extra_kwargs.get("metadata", {}) or {} - _model_group = _metadata.get("model_group") - _deployment = _metadata.get("deployment") - extra_information = f"\nModel: {model}" - if _api_base: - extra_information += f"\nAPI Base: {_api_base}" - if messages and len(messages) > 0: - extra_information += f"\nMessages: {messages}" - - if _model_group is not None: - extra_information += f"\nmodel_group: {_model_group}\n" - if _deployment is not None: - extra_information += f"\ndeployment: {_deployment}\n" - if _vertex_project is not None: - extra_information += f"\nvertex_project: {_vertex_project}\n" - if _vertex_location is not None: - extra_information += f"\nvertex_location: {_vertex_location}\n" - - # on litellm proxy add key name + team to exceptions - extra_information = _add_key_name_and_team_to_alert( - request_info=extra_information, metadata=_metadata - ) - except: - # DO NOT LET this Block raising the original exception - pass - - ################################################################################ - # End of Common Extra information Needed for all providers - ################################################################################ - - ################################################################################ - #################### Start of Provider Exception mapping #################### - ################################################################################ - - if "Request Timeout Error" in error_str or "Request timed out" in error_str: - exception_mapping_worked = True - raise Timeout( - message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or custom_llm_provider == "custom_openai" - or custom_llm_provider in litellm.openai_compatible_providers - ): - # custom_llm_provider is openai, make it OpenAI - if hasattr(original_exception, "message"): - message = original_exception.message - else: - message = str(original_exception) - if message is not None and isinstance(message, str): - message = message.replace("OPENAI", custom_llm_provider.upper()) - message = message.replace("openai", custom_llm_provider) - message = message.replace("OpenAI", custom_llm_provider) - if custom_llm_provider == "openai": - exception_provider = "OpenAI" + "Exception" - else: - exception_provider = ( - custom_llm_provider[0].upper() - + custom_llm_provider[1:] - + "Exception" - ) - - if "This model's maximum context length is" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "model_not_found" in error_str - ): - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "content_policy_violation" in error_str - ): - exception_mapping_worked = True - raise ContentPolicyViolationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "Incorrect API key provided" not in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Request too large" in error_str: - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Mistral API raised a streaming error" in error_str: - exception_mapping_worked = True - _request = httpx.Request( - method="POST", url="https://api.openai.com/v1" - ) - raise APIError( - status_code=500, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=_request, - litellm_debug_info=extra_information, - ) - elif hasattr(original_exception, "status_code"): - exception_mapping_worked = True - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - else: - exception_mapping_worked = True - raise APIError( - status_code=original_exception.status_code, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=original_exception.request, - litellm_debug_info=extra_information, - ) - else: - # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors - raise APIConnectionError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - litellm_debug_info=extra_information, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ) - elif custom_llm_provider == "anthropic": # one of the anthropics - if hasattr(original_exception, "message"): - if ( - "prompt is too long" in original_exception.message - or "prompt: length" in original_exception.message - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if "Invalid API Key" in original_exception.message: - exception_mapping_worked = True - raise AuthenticationError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if hasattr(original_exception, "status_code"): - print_verbose(f"status_code: {original_exception.status_code}") - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.", - llm_provider="anthropic", - model=model, - request=original_exception.request, - ) - elif custom_llm_provider == "replicate": - if "Incorrect authentication token" in error_str: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif "input is too long" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif exception_type == "ModelError": - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif "Request was throttled" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or original_exception.status_code == 422 - or original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"ReplicateException - {str(original_exception)}", - llm_provider="replicate", - model=model, - request=httpx.Request( - method="POST", - url="https://api.replicate.com/v1/deployments", - ), - ) - elif custom_llm_provider == "watsonx": - if "token_quota_reached" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"WatsonxException: Rate Limit Errror - {error_str}", - llm_provider="watsonx", - model=model, - response=original_exception.response, - ) - elif custom_llm_provider == "predibase": - if "authorization denied for" in error_str: - exception_mapping_worked = True - - # Predibase returns the raw API Key in the response - this block ensures it's not returned in the exception - if ( - error_str is not None - and isinstance(error_str, str) - and "bearer" in error_str.lower() - ): - # only keep the first 10 chars after the occurnence of "bearer" - _bearer_token_start_index = error_str.lower().find("bearer") - error_str = error_str[: _bearer_token_start_index + 14] - error_str += "XXXXXXX" + '"' - - raise AuthenticationError( - message=f"PredibaseException: Authentication Error - {error_str}", - llm_provider="predibase", - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "bedrock": - if ( - "too many tokens" in error_str - or "expected maxLength:" in error_str - or "Input is too long" in error_str - or "prompt: length: 1.." in error_str - or "Too many input tokens" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"BedrockException: Context Window Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "Malformed input request" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Unable to locate credentials" in error_str - or "The security token included in the request is invalid" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"BedrockException Invalid Authentication - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "AccessDeniedException" in error_str: - exception_mapping_worked = True - raise PermissionDeniedError( - message=f"BedrockException PermissionDeniedError - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "throttlingException" in error_str - or "ThrottlingException" in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException: Rate Limit Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Connect timeout on endpoint URL" in error_str - or "timed out" in error_str - ): - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException: Timeout Error - {error_str}", - model=model, - llm_provider="bedrock", - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=httpx.Response( - status_code=500, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ), - ) - elif original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "sagemaker": - if "Unable to locate credentials" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "Input validation error: `best_of` must be > 0 and <= 2" - in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "`inputs` tokens + `max_new_tokens` must be <=" in error_str - or "instance type with more CPU capacity or memory" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif custom_llm_provider == "vertex_ai": - if ( - "Vertex AI API has not been used in project" in error_str - or "Unable to find your project" in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "None Unknown Error." in error_str - or "Content has no parts." in error_str - ): - exception_mapping_worked = True - raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - request=original_exception.request, - litellm_debug_info=extra_information, - ) - elif "403" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "The response was blocked." in error_str: - exception_mapping_worked = True - raise UnprocessableEntityError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - elif ( - "429 Quota exceeded" in error_str - or "IndexError: list index out of range" in error_str - or "429 Unable to submit request because the service is temporarily out of capacity." - in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - if hasattr(original_exception, "status_code"): - if original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=original_exception.response, - ) - if original_exception.status_code == 500: - exception_mapping_worked = True -> raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - request=original_exception.request, -E litellm.exceptions.APIError: VertexAIException - Parameter to MergeFrom() must be instance of same class: expected got . - -../utils.py:8922: APIError - -During handling of the above exception, another exception occurred: - - def test_gemini_pro_vision(): - try: - load_vertex_ai_credentials() - litellm.set_verbose = True - litellm.num_retries = 3 -> resp = litellm.completion( - model="vertex_ai/gemini-1.5-flash-preview-0514", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "Whats in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" - }, - }, - ], - } - ], - ) - -test_amazing_vertex_completion.py:510: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../utils.py:3289: in wrapper - return litellm.completion_with_retries(*args, **kwargs) -../main.py:2401: in completion_with_retries - return retryer(original_function, *args, **kwargs) -../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:379: in __call__ - do = self.iter(retry_state=retry_state) -../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:325: in iter - raise retry_exc.reraise() -../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:158: in reraise - raise self.last_attempt.result() -/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/concurrent/futures/_base.py:449: in result - return self.__get_result() -/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/concurrent/futures/_base.py:401: in __get_result - raise self._exception -../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:382: in __call__ - result = fn(*args, **kwargs) -../utils.py:3317: in wrapper - raise e -../utils.py:3211: in wrapper - result = original_function(*args, **kwargs) -../main.py:2368: in completion - raise exception_type( -../utils.py:9709: in exception_type - raise e -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -model = 'gemini-1.5-flash-preview-0514' -original_exception = VertexAIError("Parameter to MergeFrom() must be instance of same class: expected got .") -custom_llm_provider = 'vertex_ai' -completion_kwargs = {'acompletion': False, 'api_base': None, 'api_key': None, 'api_version': None, ...} -extra_kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': } - - def exception_type( - model, - original_exception, - custom_llm_provider, - completion_kwargs={}, - extra_kwargs={}, - ): - global user_logger_fn, liteDebuggerClient - exception_mapping_worked = False - if litellm.suppress_debug_info is False: - print() # noqa - print( # noqa - "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa - ) # noqa - print( # noqa - "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." # noqa - ) # noqa - print() # noqa - try: - if model: - error_str = str(original_exception) - if isinstance(original_exception, BaseException): - exception_type = type(original_exception).__name__ - else: - exception_type = "" - - ################################################################################ - # Common Extra information needed for all providers - # We pass num retries, api_base, vertex_deployment etc to the exception here - ################################################################################ - extra_information = "" - try: - _api_base = litellm.get_api_base( - model=model, optional_params=extra_kwargs - ) - messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) - _vertex_project = extra_kwargs.get("vertex_project") - _vertex_location = extra_kwargs.get("vertex_location") - _metadata = extra_kwargs.get("metadata", {}) or {} - _model_group = _metadata.get("model_group") - _deployment = _metadata.get("deployment") - extra_information = f"\nModel: {model}" - if _api_base: - extra_information += f"\nAPI Base: {_api_base}" - if messages and len(messages) > 0: - extra_information += f"\nMessages: {messages}" - - if _model_group is not None: - extra_information += f"\nmodel_group: {_model_group}\n" - if _deployment is not None: - extra_information += f"\ndeployment: {_deployment}\n" - if _vertex_project is not None: - extra_information += f"\nvertex_project: {_vertex_project}\n" - if _vertex_location is not None: - extra_information += f"\nvertex_location: {_vertex_location}\n" - - # on litellm proxy add key name + team to exceptions - extra_information = _add_key_name_and_team_to_alert( - request_info=extra_information, metadata=_metadata - ) - except: - # DO NOT LET this Block raising the original exception - pass - - ################################################################################ - # End of Common Extra information Needed for all providers - ################################################################################ - - ################################################################################ - #################### Start of Provider Exception mapping #################### - ################################################################################ - - if "Request Timeout Error" in error_str or "Request timed out" in error_str: - exception_mapping_worked = True - raise Timeout( - message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or custom_llm_provider == "custom_openai" - or custom_llm_provider in litellm.openai_compatible_providers - ): - # custom_llm_provider is openai, make it OpenAI - if hasattr(original_exception, "message"): - message = original_exception.message - else: - message = str(original_exception) - if message is not None and isinstance(message, str): - message = message.replace("OPENAI", custom_llm_provider.upper()) - message = message.replace("openai", custom_llm_provider) - message = message.replace("OpenAI", custom_llm_provider) - if custom_llm_provider == "openai": - exception_provider = "OpenAI" + "Exception" - else: - exception_provider = ( - custom_llm_provider[0].upper() - + custom_llm_provider[1:] - + "Exception" - ) - - if "This model's maximum context length is" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "model_not_found" in error_str - ): - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "content_policy_violation" in error_str - ): - exception_mapping_worked = True - raise ContentPolicyViolationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "Incorrect API key provided" not in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Request too large" in error_str: - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Mistral API raised a streaming error" in error_str: - exception_mapping_worked = True - _request = httpx.Request( - method="POST", url="https://api.openai.com/v1" - ) - raise APIError( - status_code=500, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=_request, - litellm_debug_info=extra_information, - ) - elif hasattr(original_exception, "status_code"): - exception_mapping_worked = True - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - else: - exception_mapping_worked = True - raise APIError( - status_code=original_exception.status_code, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=original_exception.request, - litellm_debug_info=extra_information, - ) - else: - # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors - raise APIConnectionError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - litellm_debug_info=extra_information, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ) - elif custom_llm_provider == "anthropic": # one of the anthropics - if hasattr(original_exception, "message"): - if ( - "prompt is too long" in original_exception.message - or "prompt: length" in original_exception.message - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if "Invalid API Key" in original_exception.message: - exception_mapping_worked = True - raise AuthenticationError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if hasattr(original_exception, "status_code"): - print_verbose(f"status_code: {original_exception.status_code}") - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.", - llm_provider="anthropic", - model=model, - request=original_exception.request, - ) - elif custom_llm_provider == "replicate": - if "Incorrect authentication token" in error_str: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif "input is too long" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif exception_type == "ModelError": - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif "Request was throttled" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or original_exception.status_code == 422 - or original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"ReplicateException - {str(original_exception)}", - llm_provider="replicate", - model=model, - request=httpx.Request( - method="POST", - url="https://api.replicate.com/v1/deployments", - ), - ) - elif custom_llm_provider == "watsonx": - if "token_quota_reached" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"WatsonxException: Rate Limit Errror - {error_str}", - llm_provider="watsonx", - model=model, - response=original_exception.response, - ) - elif custom_llm_provider == "predibase": - if "authorization denied for" in error_str: - exception_mapping_worked = True - - # Predibase returns the raw API Key in the response - this block ensures it's not returned in the exception - if ( - error_str is not None - and isinstance(error_str, str) - and "bearer" in error_str.lower() - ): - # only keep the first 10 chars after the occurnence of "bearer" - _bearer_token_start_index = error_str.lower().find("bearer") - error_str = error_str[: _bearer_token_start_index + 14] - error_str += "XXXXXXX" + '"' - - raise AuthenticationError( - message=f"PredibaseException: Authentication Error - {error_str}", - llm_provider="predibase", - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "bedrock": - if ( - "too many tokens" in error_str - or "expected maxLength:" in error_str - or "Input is too long" in error_str - or "prompt: length: 1.." in error_str - or "Too many input tokens" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"BedrockException: Context Window Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "Malformed input request" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Unable to locate credentials" in error_str - or "The security token included in the request is invalid" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"BedrockException Invalid Authentication - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "AccessDeniedException" in error_str: - exception_mapping_worked = True - raise PermissionDeniedError( - message=f"BedrockException PermissionDeniedError - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "throttlingException" in error_str - or "ThrottlingException" in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException: Rate Limit Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Connect timeout on endpoint URL" in error_str - or "timed out" in error_str - ): - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException: Timeout Error - {error_str}", - model=model, - llm_provider="bedrock", - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=httpx.Response( - status_code=500, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ), - ) - elif original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "sagemaker": - if "Unable to locate credentials" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "Input validation error: `best_of` must be > 0 and <= 2" - in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "`inputs` tokens + `max_new_tokens` must be <=" in error_str - or "instance type with more CPU capacity or memory" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif custom_llm_provider == "vertex_ai": - if ( - "Vertex AI API has not been used in project" in error_str - or "Unable to find your project" in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "None Unknown Error." in error_str - or "Content has no parts." in error_str - ): - exception_mapping_worked = True - raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - request=original_exception.request, - litellm_debug_info=extra_information, - ) - elif "403" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "The response was blocked." in error_str: - exception_mapping_worked = True - raise UnprocessableEntityError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - elif ( - "429 Quota exceeded" in error_str - or "IndexError: list index out of range" in error_str - or "429 Unable to submit request because the service is temporarily out of capacity." - in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - if hasattr(original_exception, "status_code"): - if original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=original_exception.response, - ) - if original_exception.status_code == 500: - exception_mapping_worked = True -> raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - request=original_exception.request, -E litellm.exceptions.APIError: VertexAIException - Parameter to MergeFrom() must be instance of same class: expected got . - -../utils.py:8922: APIError - -During handling of the above exception, another exception occurred: - - def test_gemini_pro_vision(): - try: - load_vertex_ai_credentials() - litellm.set_verbose = True - litellm.num_retries = 3 - resp = litellm.completion( - model="vertex_ai/gemini-1.5-flash-preview-0514", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "Whats in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" - }, - }, - ], - } - ], - ) - print(resp) - - prompt_tokens = resp.usage.prompt_tokens - - # DO Not DELETE this ASSERT - # Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response - assert prompt_tokens == 263 # the gemini api returns 263 to us - except litellm.RateLimitError as e: - pass - except Exception as e: - if "500 Internal error encountered.'" in str(e): - pass - else: -> pytest.fail(f"An exception occurred - {str(e)}") -E Failed: An exception occurred - VertexAIException - Parameter to MergeFrom() must be instance of same class: expected got . - -test_amazing_vertex_completion.py:540: Failed ----------------------------- Captured stdout setup ----------------------------- - ------------------------------ Captured stdout call ----------------------------- -loading vertex ai credentials -Read vertexai file path - - -Request to litellm: -litellm.completion(model='vertex_ai/gemini-1.5-flash-preview-0514', messages=[{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}]) - - -self.optional_params: {} -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK -(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {} -Final returned optional params: {} -self.optional_params: {} -VERTEX AI: vertex_project=None; vertex_location=None -VERTEX AI: creds=; google application credentials: /var/folders/gf/5h3fnlwx40sdrycs4y5qzqx40000gn/T/tmpolsest5s - -Making VertexAI Gemini Pro / Pro Vision Call - -Processing input messages = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}] - -Request Sent from LiteLLM: -llm_model = GenerativeModel(gemini-1.5-flash-preview-0514) -response = llm_model.generate_content([{'role': 'user', 'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -]}]) - - - -Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new -LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. - -Logging Details: logger_fn - None | callable(logger_fn) - False - - -Request to litellm: -litellm.completion(model='vertex_ai/gemini-1.5-flash-preview-0514', messages=[{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}], litellm_call_id='7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', litellm_logging_obj=) - - -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK -(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {} -Final returned optional params: {} -self.optional_params: {} -VERTEX AI: vertex_project=None; vertex_location=None -VERTEX AI: creds=; google application credentials: /var/folders/gf/5h3fnlwx40sdrycs4y5qzqx40000gn/T/tmpolsest5s - -Making VertexAI Gemini Pro / Pro Vision Call - -Processing input messages = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}] - -Request Sent from LiteLLM: -llm_model = GenerativeModel(gemini-1.5-flash-preview-0514) -response = llm_model.generate_content([{'role': 'user', 'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -]}]) - - - -Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new -LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. - -Logging Details: logger_fn - None | callable(logger_fn) - False -Logging Details LiteLLM-Failure Call -self.failure_callback: [] - - -Request to litellm: -litellm.completion(model='vertex_ai/gemini-1.5-flash-preview-0514', messages=[{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}], litellm_call_id='7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', litellm_logging_obj=) - - -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK -(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {} -Final returned optional params: {} -self.optional_params: {} -VERTEX AI: vertex_project=None; vertex_location=None -VERTEX AI: creds=; google application credentials: /var/folders/gf/5h3fnlwx40sdrycs4y5qzqx40000gn/T/tmpolsest5s - -Making VertexAI Gemini Pro / Pro Vision Call - -Processing input messages = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}] - -Request Sent from LiteLLM: -llm_model = GenerativeModel(gemini-1.5-flash-preview-0514) -response = llm_model.generate_content([{'role': 'user', 'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -]}]) - - - -Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new -LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. - -Logging Details: logger_fn - None | callable(logger_fn) - False -Logging Details LiteLLM-Failure Call -self.failure_callback: [] - - -Request to litellm: -litellm.completion(model='vertex_ai/gemini-1.5-flash-preview-0514', messages=[{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}], litellm_call_id='7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', litellm_logging_obj=) - - -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK -(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {} -Final returned optional params: {} -self.optional_params: {} -VERTEX AI: vertex_project=None; vertex_location=None -VERTEX AI: creds=; google application credentials: /var/folders/gf/5h3fnlwx40sdrycs4y5qzqx40000gn/T/tmpolsest5s - -Making VertexAI Gemini Pro / Pro Vision Call - -Processing input messages = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}] - -Request Sent from LiteLLM: -llm_model = GenerativeModel(gemini-1.5-flash-preview-0514) -response = llm_model.generate_content([{'role': 'user', 'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -]}]) - - - -Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new -LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. - -Logging Details: logger_fn - None | callable(logger_fn) - False -Logging Details LiteLLM-Failure Call -self.failure_callback: [] -=============================== warnings summary =============================== -../proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings - /Users/krrishdholakia/Documents/litellm/litellm/proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning) - -../proxy/_types.py:255 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:255: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:342 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:342: PydanticDeprecatedSince20: `pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - extra = Extra.allow # Allow extra fields - -../proxy/_types.py:345 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:345: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:374 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:374: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:421 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:421: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:490 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:490: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:510 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:510: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:523 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:523: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:568 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:568: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:605 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:605: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:923 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:923: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:950 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:950: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:971 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:971: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../utils.py:60 - /Users/krrishdholakia/Documents/litellm/litellm/utils.py:60: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice. - with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f: - --- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html -=========================== short test summary info ============================ -FAILED test_amazing_vertex_completion.py::test_gemini_pro_vision - Failed: An... -======================== 1 failed, 39 warnings in 2.09s ======================== diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 047f0cb2e2..64e7741e2a 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -243,6 +243,7 @@ def test_completion_bedrock_claude_sts_oidc_auth(): except Exception as e: pytest.fail(f"Error occurred: {e}") + @pytest.mark.skipif( os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None, reason="Cannot run without being in CircleCI Runner", @@ -277,7 +278,15 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth(): except Exception as e: pytest.fail(f"Error occurred: {e}") -def test_bedrock_claude_3(): + +@pytest.mark.parametrize( + "image_url", + [ + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC", + "https://avatars.githubusercontent.com/u/29436595?v=", + ], +) +def test_bedrock_claude_3(image_url): try: litellm.set_verbose = True data = { @@ -294,7 +303,7 @@ def test_bedrock_claude_3(): { "image_url": { "detail": "high", - "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC", + "url": image_url, }, "type": "image_url", }, @@ -313,7 +322,6 @@ def test_bedrock_claude_3(): # Add any assertions here to check the response assert len(response.choices) > 0 assert len(response.choices[0].message.content) > 0 - except RateLimitError: pass except Exception as e: @@ -552,7 +560,7 @@ def test_bedrock_ptu(): assert "url" in mock_client_post.call_args.kwargs assert ( mock_client_post.call_args.kwargs["url"] - == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke" + == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/converse" ) mock_client_post.assert_called_once() diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 5feacecd2e..98898052b1 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -300,7 +300,11 @@ def test_completion_claude_3(): pytest.fail(f"Error occurred: {e}") -def test_completion_claude_3_function_call(): +@pytest.mark.parametrize( + "model", + ["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"], +) +def test_completion_claude_3_function_call(model): litellm.set_verbose = True tools = [ { @@ -331,13 +335,14 @@ def test_completion_claude_3_function_call(): try: # test without max tokens response = completion( - model="anthropic/claude-3-opus-20240229", + model=model, messages=messages, tools=tools, tool_choice={ "type": "function", "function": {"name": "get_current_weather"}, }, + drop_params=True, ) # Add any assertions, here to check response args @@ -364,10 +369,11 @@ def test_completion_claude_3_function_call(): ) # In the second response, Claude should deduce answer from tool results second_response = completion( - model="anthropic/claude-3-opus-20240229", + model=model, messages=messages, tools=tools, tool_choice="auto", + drop_params=True, ) print(second_response) except Exception as e: @@ -1398,7 +1404,6 @@ def test_hf_test_completion_tgi(): def mock_post(url, data=None, json=None, headers=None): - print(f"url={url}") if "text-classification" in url: raise Exception("Model not found") @@ -2241,9 +2246,6 @@ def test_re_use_openaiClient(): pytest.fail("got Exception", e) -# test_re_use_openaiClient() - - def test_completion_azure(): try: print("azure gpt-3.5 test\n\n") diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py index 2fc04ec528..b3aafab6e6 100644 --- a/litellm/tests/test_prompt_factory.py +++ b/litellm/tests/test_prompt_factory.py @@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import ( claude_2_1_pt, llama_2_chat_pt, prompt_factory, + _bedrock_tools_pt, ) @@ -128,3 +129,27 @@ def test_anthropic_messages_pt(): # codellama_prompt_format() +def test_bedrock_tool_calling_pt(): + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + converted_tools = _bedrock_tools_pt(tools=tools) + + print(converted_tools) diff --git a/litellm/tests/test_proxy_exception_mapping.py b/litellm/tests/test_proxy_exception_mapping.py index ccd071d01e..4988426616 100644 --- a/litellm/tests/test_proxy_exception_mapping.py +++ b/litellm/tests/test_proxy_exception_mapping.py @@ -210,7 +210,9 @@ def test_chat_completion_exception_any_model(client): ) assert isinstance(openai_exception, openai.BadRequestError) _error_message = openai_exception.message - assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str(_error_message) + assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str( + _error_message + ) except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") @@ -238,7 +240,9 @@ def test_embedding_exception_any_model(client): print("Exception raised=", openai_exception) assert isinstance(openai_exception, openai.BadRequestError) _error_message = openai_exception.message - assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str(_error_message) + assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str( + _error_message + ) except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index c24de601f5..a5e098b027 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1284,18 +1284,18 @@ async def test_completion_replicate_llama3_streaming(sync_mode): # pytest.fail(f"Error occurred: {e}") -@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.parametrize("sync_mode", [True]) # False @pytest.mark.parametrize( "model", [ - # "bedrock/cohere.command-r-plus-v1:0", - # "anthropic.claude-3-sonnet-20240229-v1:0", - # "anthropic.claude-instant-v1", - # "bedrock/ai21.j2-mid", - # "mistral.mistral-7b-instruct-v0:2", - # "bedrock/amazon.titan-tg1-large", - # "meta.llama3-8b-instruct-v1:0", - "cohere.command-text-v14" + "bedrock/cohere.command-r-plus-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-instant-v1", + "bedrock/ai21.j2-mid", + "mistral.mistral-7b-instruct-v0:2", + "bedrock/amazon.titan-tg1-large", + "meta.llama3-8b-instruct-v1:0", + "cohere.command-text-v14", ], ) @pytest.mark.asyncio diff --git a/litellm/types/files.py b/litellm/types/files.py new file mode 100644 index 0000000000..2da8fe4806 --- /dev/null +++ b/litellm/types/files.py @@ -0,0 +1,281 @@ +from enum import Enum +from types import MappingProxyType +from typing import List, Set + +""" +Base Enums/Consts +""" + + +class FileType(Enum): + AAC = "AAC" + CSV = "CSV" + DOC = "DOC" + DOCX = "DOCX" + FLAC = "FLAC" + FLV = "FLV" + GIF = "GIF" + GOOGLE_DOC = "GOOGLE_DOC" + GOOGLE_DRAWINGS = "GOOGLE_DRAWINGS" + GOOGLE_SHEETS = "GOOGLE_SHEETS" + GOOGLE_SLIDES = "GOOGLE_SLIDES" + HEIC = "HEIC" + HEIF = "HEIF" + HTML = "HTML" + JPEG = "JPEG" + JSON = "JSON" + M4A = "M4A" + M4V = "M4V" + MOV = "MOV" + MP3 = "MP3" + MP4 = "MP4" + MPEG = "MPEG" + MPEGPS = "MPEGPS" + MPG = "MPG" + MPA = "MPA" + MPGA = "MPGA" + OGG = "OGG" + OPUS = "OPUS" + PDF = "PDF" + PCM = "PCM" + PNG = "PNG" + PPT = "PPT" + PPTX = "PPTX" + RTF = "RTF" + THREE_GPP = "3GPP" + TXT = "TXT" + WAV = "WAV" + WEBM = "WEBM" + WEBP = "WEBP" + WMV = "WMV" + XLS = "XLS" + XLSX = "XLSX" + + +FILE_EXTENSIONS: MappingProxyType[FileType, List[str]] = MappingProxyType( + { + FileType.AAC: ["aac"], + FileType.CSV: ["csv"], + FileType.DOC: ["doc"], + FileType.DOCX: ["docx"], + FileType.FLAC: ["flac"], + FileType.FLV: ["flv"], + FileType.GIF: ["gif"], + FileType.GOOGLE_DOC: ["gdoc"], + FileType.GOOGLE_DRAWINGS: ["gdraw"], + FileType.GOOGLE_SHEETS: ["gsheet"], + FileType.GOOGLE_SLIDES: ["gslides"], + FileType.HEIC: ["heic"], + FileType.HEIF: ["heif"], + FileType.HTML: ["html", "htm"], + FileType.JPEG: ["jpeg", "jpg"], + FileType.JSON: ["json"], + FileType.M4A: ["m4a"], + FileType.M4V: ["m4v"], + FileType.MOV: ["mov"], + FileType.MP3: ["mp3"], + FileType.MP4: ["mp4"], + FileType.MPEG: ["mpeg"], + FileType.MPEGPS: ["mpegps"], + FileType.MPG: ["mpg"], + FileType.MPA: ["mpa"], + FileType.MPGA: ["mpga"], + FileType.OGG: ["ogg"], + FileType.OPUS: ["opus"], + FileType.PDF: ["pdf"], + FileType.PCM: ["pcm"], + FileType.PNG: ["png"], + FileType.PPT: ["ppt"], + FileType.PPTX: ["pptx"], + FileType.RTF: ["rtf"], + FileType.THREE_GPP: ["3gpp"], + FileType.TXT: ["txt"], + FileType.WAV: ["wav"], + FileType.WEBM: ["webm"], + FileType.WEBP: ["webp"], + FileType.WMV: ["wmv"], + FileType.XLS: ["xls"], + FileType.XLSX: ["xlsx"], + } +) + +FILE_MIME_TYPES: MappingProxyType[FileType, str] = MappingProxyType( + { + FileType.AAC: "audio/aac", + FileType.CSV: "text/csv", + FileType.DOC: "application/msword", + FileType.DOCX: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + FileType.FLAC: "audio/flac", + FileType.FLV: "video/x-flv", + FileType.GIF: "image/gif", + FileType.GOOGLE_DOC: "application/vnd.google-apps.document", + FileType.GOOGLE_DRAWINGS: "application/vnd.google-apps.drawing", + FileType.GOOGLE_SHEETS: "application/vnd.google-apps.spreadsheet", + FileType.GOOGLE_SLIDES: "application/vnd.google-apps.presentation", + FileType.HEIC: "image/heic", + FileType.HEIF: "image/heif", + FileType.HTML: "text/html", + FileType.JPEG: "image/jpeg", + FileType.JSON: "application/json", + FileType.M4A: "audio/x-m4a", + FileType.M4V: "video/x-m4v", + FileType.MOV: "video/quicktime", + FileType.MP3: "audio/mpeg", + FileType.MP4: "video/mp4", + FileType.MPEG: "video/mpeg", + FileType.MPEGPS: "video/mpegps", + FileType.MPG: "video/mpg", + FileType.MPA: "audio/m4a", + FileType.MPGA: "audio/mpga", + FileType.OGG: "audio/ogg", + FileType.OPUS: "audio/opus", + FileType.PDF: "application/pdf", + FileType.PCM: "audio/pcm", + FileType.PNG: "image/png", + FileType.PPT: "application/vnd.ms-powerpoint", + FileType.PPTX: "application/vnd.openxmlformats-officedocument.presentationml.presentation", + FileType.RTF: "application/rtf", + FileType.THREE_GPP: "video/3gpp", + FileType.TXT: "text/plain", + FileType.WAV: "audio/wav", + FileType.WEBM: "video/webm", + FileType.WEBP: "image/webp", + FileType.WMV: "video/wmv", + FileType.XLS: "application/vnd.ms-excel", + FileType.XLSX: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + } +) + +""" +Util Functions +""" + + +def get_file_mime_type_from_extension(extension: str) -> str: + for file_type, extensions in FILE_EXTENSIONS.items(): + if extension in extensions: + return FILE_MIME_TYPES[file_type] + raise ValueError(f"Unknown mime type for extension: {extension}") + + +def get_file_extension_from_mime_type(mime_type: str) -> str: + for file_type, mime in FILE_MIME_TYPES.items(): + if mime == mime_type: + return FILE_EXTENSIONS[file_type][0] + raise ValueError(f"Unknown extension for mime type: {mime_type}") + + +def get_file_type_from_extension(extension: str) -> FileType: + for file_type, extensions in FILE_EXTENSIONS.items(): + if extension in extensions: + return file_type + + raise ValueError(f"Unknown file type for extension: {extension}") + + +def get_file_extension_for_file_type(file_type: FileType) -> str: + return FILE_EXTENSIONS[file_type][0] + + +def get_file_mime_type_for_file_type(file_type: FileType) -> str: + return FILE_MIME_TYPES[file_type] + + +""" +FileType Type Groupings (Videos, Images, etc) +""" + +# Images +IMAGE_FILE_TYPES = { + FileType.PNG, + FileType.JPEG, + FileType.GIF, + FileType.WEBP, + FileType.HEIC, + FileType.HEIF, +} + + +def is_image_file_type(file_type): + return file_type in IMAGE_FILE_TYPES + + +# Videos +VIDEO_FILE_TYPES = { + FileType.MOV, + FileType.MP4, + FileType.MPEG, + FileType.M4V, + FileType.FLV, + FileType.MPEGPS, + FileType.MPG, + FileType.WEBM, + FileType.WMV, + FileType.THREE_GPP, +} + + +def is_video_file_type(file_type): + return file_type in VIDEO_FILE_TYPES + + +# Audio +AUDIO_FILE_TYPES = { + FileType.AAC, + FileType.FLAC, + FileType.MP3, + FileType.MPA, + FileType.MPGA, + FileType.OPUS, + FileType.PCM, + FileType.WAV, +} + + +def is_audio_file_type(file_type): + return file_type in AUDIO_FILE_TYPES + + +# Text +TEXT_FILE_TYPES = {FileType.CSV, FileType.HTML, FileType.RTF, FileType.TXT} + + +def is_text_file_type(file_type): + return file_type in TEXT_FILE_TYPES + + +""" +Other FileType Groupings +""" +# Accepted file types for GEMINI 1.5 through Vertex AI +# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#gemini-send-multimodal-samples-images-nodejs +GEMINI_1_5_ACCEPTED_FILE_TYPES: Set[FileType] = { + # Image + FileType.PNG, + FileType.JPEG, + # Audio + FileType.AAC, + FileType.FLAC, + FileType.MP3, + FileType.MPA, + FileType.MPGA, + FileType.OPUS, + FileType.PCM, + FileType.WAV, + # Video + FileType.FLV, + FileType.MOV, + FileType.MPEG, + FileType.MPEGPS, + FileType.MPG, + FileType.MP4, + FileType.WEBM, + FileType.WMV, + FileType.THREE_GPP, + # PDF + FileType.PDF, +} + + +def is_gemini_1_5_accepted_file_type(file_type: FileType) -> bool: + return file_type in GEMINI_1_5_ACCEPTED_FILE_TYPES diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 0c82596827..b06075092f 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -1,4 +1,4 @@ -from typing import TypedDict, Any, Union, Optional +from typing import TypedDict, Any, Union, Optional, Literal, List import json from typing_extensions import ( Self, @@ -11,10 +11,137 @@ from typing_extensions import ( ) +class SystemContentBlock(TypedDict): + text: str + + +class ImageSourceBlock(TypedDict): + bytes: Optional[str] # base 64 encoded string + + +class ImageBlock(TypedDict): + format: Literal["png", "jpeg", "gif", "webp"] + source: ImageSourceBlock + + +class ToolResultContentBlock(TypedDict, total=False): + image: ImageBlock + json: dict + text: str + + +class ToolResultBlock(TypedDict, total=False): + content: Required[List[ToolResultContentBlock]] + toolUseId: Required[str] + status: Literal["success", "error"] + + +class ToolUseBlock(TypedDict): + input: dict + name: str + toolUseId: str + + +class ContentBlock(TypedDict, total=False): + text: str + image: ImageBlock + toolResult: ToolResultBlock + toolUse: ToolUseBlock + + +class MessageBlock(TypedDict): + content: List[ContentBlock] + role: Literal["user", "assistant"] + + +class ConverseMetricsBlock(TypedDict): + latencyMs: float # time in ms + + +class ConverseResponseOutputBlock(TypedDict): + message: Optional[MessageBlock] + + +class ConverseTokenUsageBlock(TypedDict): + inputTokens: int + outputTokens: int + totalTokens: int + + +class ConverseResponseBlock(TypedDict): + additionalModelResponseFields: dict + metrics: ConverseMetricsBlock + output: ConverseResponseOutputBlock + stopReason: ( + str # end_turn | tool_use | max_tokens | stop_sequence | content_filtered + ) + usage: ConverseTokenUsageBlock + + +class ToolInputSchemaBlock(TypedDict): + json: Optional[dict] + + +class ToolSpecBlock(TypedDict, total=False): + inputSchema: Required[ToolInputSchemaBlock] + name: Required[str] + description: str + + +class ToolBlock(TypedDict): + toolSpec: Optional[ToolSpecBlock] + + +class SpecificToolChoiceBlock(TypedDict): + name: str + + +class ToolChoiceValuesBlock(TypedDict, total=False): + any: dict + auto: dict + tool: SpecificToolChoiceBlock + + +class ToolConfigBlock(TypedDict, total=False): + tools: Required[List[ToolBlock]] + toolChoice: Union[str, ToolChoiceValuesBlock] + + +class InferenceConfig(TypedDict, total=False): + maxTokens: int + stopSequences: List[str] + temperature: float + topP: float + + +class ToolBlockDeltaEvent(TypedDict): + input: str + + +class ContentBlockDeltaEvent(TypedDict, total=False): + """ + Either 'text' or 'toolUse' will be specified for Converse API streaming response. + """ + + text: str + toolUse: ToolBlockDeltaEvent + + +class RequestObject(TypedDict, total=False): + additionalModelRequestFields: dict + additionalModelResponseFieldPaths: List[str] + inferenceConfig: InferenceConfig + messages: Required[List[MessageBlock]] + system: List[SystemContentBlock] + toolConfig: ToolConfigBlock + + class GenericStreamingChunk(TypedDict): text: Required[str] + tool_str: Required[str] is_finished: Required[bool] finish_reason: Required[str] + usage: Optional[ConverseTokenUsageBlock] class Document(TypedDict): diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index bc0c82434f..7861e394cd 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False): extra_headers: Optional[Dict[str, str]] extra_body: Optional[Dict[str, str]] timeout: Optional[float] + + +class ChatCompletionToolCallFunctionChunk(TypedDict): + name: str + arguments: str + + +class ChatCompletionToolCallChunk(TypedDict): + id: str + type: Literal["function"] + function: ChatCompletionToolCallFunctionChunk + + +class ChatCompletionResponseMessage(TypedDict, total=False): + content: Optional[str] + tool_calls: List[ChatCompletionToolCallChunk] + role: Literal["assistant"] diff --git a/litellm/types/services.py b/litellm/types/services.py index b694ca8078..9c3c2120eb 100644 --- a/litellm/types/services.py +++ b/litellm/types/services.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field from typing import Optional -class ServiceTypes(enum.Enum): +class ServiceTypes(str, enum.Enum): """ Enum for litellm + litellm-adjacent services (redis/postgres/etc.) """ diff --git a/litellm/utils.py b/litellm/utils.py index e8d44e87d0..d8c0e48af1 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -239,6 +239,8 @@ def map_finish_reason( return "length" elif finish_reason == "tool_use": # anthropic return "tool_calls" + elif finish_reason == "content_filtered": + return "content_filter" return finish_reason @@ -5655,19 +5657,29 @@ def get_optional_params( optional_params["stream"] = stream elif "anthropic" in model: _check_valid_arg(supported_params=supported_params) - # anthropic params on bedrock - # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}" - if model.startswith("anthropic.claude-3"): - optional_params = ( - litellm.AmazonAnthropicClaude3Config().map_openai_params( + if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route. + if model.startswith("anthropic.claude-3"): + optional_params = ( + litellm.AmazonAnthropicClaude3Config().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + ) + ) + else: + optional_params = litellm.AmazonAnthropicConfig().map_openai_params( non_default_params=non_default_params, optional_params=optional_params, ) - ) - else: - optional_params = litellm.AmazonAnthropicConfig().map_openai_params( + else: # bedrock httpx route + optional_params = litellm.AmazonConverseConfig().map_openai_params( + model=model, non_default_params=non_default_params, optional_params=optional_params, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), ) elif "amazon" in model: # amazon titan llms _check_valid_arg(supported_params=supported_params) @@ -6445,20 +6457,7 @@ def get_supported_openai_params( - None if unmapped """ if custom_llm_provider == "bedrock": - if model.startswith("anthropic.claude-3"): - return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() - elif model.startswith("anthropic"): - return litellm.AmazonAnthropicConfig().get_supported_openai_params() - elif model.startswith("ai21"): - return ["max_tokens", "temperature", "top_p", "stream"] - elif model.startswith("amazon"): - return ["max_tokens", "temperature", "stop", "top_p", "stream"] - elif model.startswith("meta"): - return ["max_tokens", "temperature", "top_p", "stream"] - elif model.startswith("cohere"): - return ["stream", "temperature", "max_tokens"] - elif model.startswith("mistral"): - return ["max_tokens", "temperature", "stop", "top_p", "stream"] + return litellm.AmazonConverseConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "ollama": return litellm.OllamaConfig().get_supported_openai_params() elif custom_llm_provider == "ollama_chat": @@ -8558,7 +8557,11 @@ def exception_type( extra_information = f"\nModel: {model}" if _api_base: extra_information += f"\nAPI Base: `{_api_base}`" - if messages and len(messages) > 0: + if ( + messages + and len(messages) > 0 + and litellm.redact_messages_in_exceptions is False + ): extra_information += f"\nMessages: `{messages}`" if _model_group is not None: @@ -9124,7 +9127,7 @@ def exception_type( if "Unable to locate credentials" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"SagemakerException - {error_str}", + message=f"litellm.BadRequestError: SagemakerException - {error_str}", model=model, llm_provider="sagemaker", response=original_exception.response, @@ -9158,10 +9161,16 @@ def exception_type( ): exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException BadRequestError - {error_str}", + message=f"litellm.BadRequestError: VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", - response=original_exception.response, + response=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), litellm_debug_info=extra_information, ) elif ( @@ -9169,12 +9178,19 @@ def exception_type( or "Content has no parts." in error_str ): exception_mapping_worked = True - raise APIError( - message=f"VertexAIException APIError - {error_str}", + raise litellm.InternalServerError( + message=f"litellm.InternalServerError: VertexAIException - {error_str}", status_code=500, model=model, llm_provider="vertex_ai", - request=original_exception.request, + request=( + original_exception.request + if hasattr(original_exception, "request") + else httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ) + ), litellm_debug_info=extra_information, ) elif "403" in error_str: @@ -9183,7 +9199,13 @@ def exception_type( message=f"VertexAIException BadRequestError - {error_str}", model=model, llm_provider="vertex_ai", - response=original_exception.response, + response=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), litellm_debug_info=extra_information, ) elif "The response was blocked." in error_str: @@ -9230,12 +9252,18 @@ def exception_type( model=model, llm_provider="vertex_ai", litellm_debug_info=extra_information, - response=original_exception.response, + response=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), ) if original_exception.status_code == 500: exception_mapping_worked = True - raise APIError( - message=f"VertexAIException APIError - {error_str}", + raise litellm.InternalServerError( + message=f"VertexAIException InternalServerError - {error_str}", status_code=500, model=model, llm_provider="vertex_ai", @@ -11423,12 +11451,27 @@ class CustomStreamWrapper: if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "bedrock": + from litellm.types.llms.bedrock import GenericStreamingChunk + if self.received_finish_reason is not None: raise StopIteration - response_obj = self.handle_bedrock_stream(chunk) + response_obj: GenericStreamingChunk = chunk completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] + + if ( + self.stream_options + and self.stream_options.get("include_usage", False) is True + and response_obj["usage"] is not None + ): + self.sent_stream_usage = True + model_response.usage = litellm.Usage( + prompt_tokens=response_obj["usage"]["inputTokens"], + completion_tokens=response_obj["usage"]["outputTokens"], + total_tokens=response_obj["usage"]["totalTokens"], + ) elif self.custom_llm_provider == "sagemaker": print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") response_obj = self.handle_sagemaker_stream(chunk) @@ -11695,7 +11738,7 @@ class CustomStreamWrapper: and hasattr(model_response, "usage") and hasattr(model_response.usage, "prompt_tokens") ): - if self.sent_first_chunk == False: + if self.sent_first_chunk is False: completion_obj["role"] = "assistant" self.sent_first_chunk = True model_response.choices[0].delta = Delta(**completion_obj) @@ -11863,6 +11906,8 @@ class CustomStreamWrapper: def __next__(self): try: + if self.completion_stream is None: + self.fetch_sync_stream() while True: if ( isinstance(self.completion_stream, str) @@ -11937,6 +11982,14 @@ class CustomStreamWrapper: custom_llm_provider=self.custom_llm_provider, ) + def fetch_sync_stream(self): + if self.completion_stream is None and self.make_call is not None: + # Call make_call to get the completion stream + self.completion_stream = self.make_call(client=litellm.module_level_client) + self._stream_iter = self.completion_stream.__iter__() + + return self.completion_stream + async def fetch_stream(self): if self.completion_stream is None and self.make_call is not None: # Call make_call to get the completion stream diff --git a/pyproject.toml b/pyproject.toml index c346640396..a472ae1956 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,10 @@ description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" readme = "README.md" +packages = [ + { include = "litellm" }, + { include = "litellm/py.typed"}, +] [tool.poetry.urls] homepage = "https://litellm.ai" diff --git a/ruff.toml b/ruff.toml index 8aa9d70732..dfb323c1b3 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1 +1,3 @@ -ignore = ["F403", "F401"] \ No newline at end of file +ignore = ["F405"] +extend-select = ["E501"] +line-length = 120