Litellm remove circular imports (#7232)

* fix(utils.py): initial commit to remove circular imports - moves `LlmProviders` to utils.py

* fix(router.py): fix 'litellm.EmbeddingResponse' import from router.py


* refactor: fix litellm.ModelResponse import in pass-through endpoints

* refactor(litellm_logging.py): fix circular import for custom callbacks literal

* fix(factory.py): fix circular imports inside prompt factory

* fix(cost_calculator.py): fix circular import for 'litellm.Usage'

* fix(proxy_server.py): fix potential circular import with `litellm.Router`

* fix(proxy/utils.py): fix potential circular import in `litellm.Router`

* fix: remove circular imports in 'auth_checks' and 'guardrails/'

* fix(prompt_injection_detection.py): fix router import

* fix(vertex_passthrough_logging_handler.py): fix potential circular imports in vertex pass through

* fix(anthropic_pass_through_logging_handler.py): fix potential circular imports

* fix(slack_alerting.py-+-ollama_chat.py): fix ModelResponse import

* fix(base.py): fix potential circular import

* fix(handler.py): fix potential circular ref in codestral + cohere handlers

* fix(azure.py): fix potential circular imports

* fix(gpt_transformation.py): fix ModelResponse import

* fix(litellm_logging.py): add logging base class - simplify typing

makes it easy for other files to type-check the logging object without introducing circular imports

* fix(azure_ai/embed): fix potential circular import on handler.py

* fix(databricks/): fix potential circular imports in databricks/

* fix(vertex_ai/): fix potential circular imports on vertex ai embeddings

* fix(vertex_ai/image_gen): fix import

* fix(watsonx-+-bedrock): cleanup imports

* refactor(anthropic-pass-through-+-petals): cleanup imports

* refactor(huggingface/): cleanup imports

* fix(ollama-+-clarifai): cleanup circular imports

* fix(openai_like/): fix import

* fix(openai_like/): fix embedding handler

cleanup imports

* refactor(openai.py): cleanup imports

* fix(sagemaker/transformation.py): fix import

* ci(config.yml): add circular import test to CI/CD
Krish Dholakia 2024-12-14 16:28:34 -08:00 committed by GitHub
parent 0dbf71291e
commit 516c2a6a70
48 changed files with 489 additions and 256 deletions

@@ -817,6 +817,7 @@ jobs:
 - run: python ./tests/documentation_tests/test_api_docs.py
 - run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
 - run: python ./tests/code_coverage_tests/enforce_llms_folder_style.py
+- run: python ./tests/documentation_tests/test_circular_imports.py
 - run: helm lint ./deploy/charts/litellm-helm
 db_migration_disable_update_check:
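The new CI step runs tests/documentation_tests/test_circular_imports.py, which is not included in this diff. As a rough sketch only (the real test's approach may differ), a circular-import check can simply try to import every submodule in isolation and fail if any import blows up:

```python
# Hypothetical sketch -- the real tests/documentation_tests/test_circular_imports.py
# is not shown in this commit.
import importlib
import pkgutil

import litellm


def test_no_circular_imports():
    """Import every litellm submodule and collect failures; circular imports
    typically surface as ImportError/AttributeError at import time."""
    failures = []
    for module_info in pkgutil.walk_packages(
        litellm.__path__, prefix="litellm.", onerror=lambda name: failures.append(name)
    ):
        try:
            importlib.import_module(module_info.name)
        except Exception as exc:
            failures.append(f"{module_info.name}: {exc}")
    assert not failures, "import failures (possible circular imports):\n" + "\n".join(failures)


if __name__ == "__main__":
    test_no_circular_imports()
```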

@@ -474,12 +474,9 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
 from detect_secrets import SecretsCollection
 from detect_secrets.settings import default_settings
-print("INSIDE SECRET DETECTION PRE-CALL HOOK!")
 if await self.should_run_check(user_api_key_dict) is False:
 return
-print("RUNNING CHECK!")
 if "messages" in data and isinstance(data["messages"], list):
 for message in data["messages"]:
 if "content" in message and isinstance(message["content"], str):

@@ -32,7 +32,7 @@ from litellm.proxy._types import (
 KeyManagementSettings,
 LiteLLM_UpperboundKeyGenerateParams,
 )
-from litellm.types.utils import StandardKeyGenerationConfig
+from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
 import httpx
 import dotenv
 from enum import Enum
@@ -838,71 +838,6 @@ model_list = (
 )
-class LlmProviders(str, Enum):
-OPENAI = "openai"
-OPENAI_LIKE = "openai_like" # embedding only
-JINA_AI = "jina_ai"
-XAI = "xai"
-CUSTOM_OPENAI = "custom_openai"
-TEXT_COMPLETION_OPENAI = "text-completion-openai"
-COHERE = "cohere"
-COHERE_CHAT = "cohere_chat"
-CLARIFAI = "clarifai"
-ANTHROPIC = "anthropic"
-ANTHROPIC_TEXT = "anthropic_text"
-REPLICATE = "replicate"
-HUGGINGFACE = "huggingface"
-TOGETHER_AI = "together_ai"
-OPENROUTER = "openrouter"
-VERTEX_AI = "vertex_ai"
-VERTEX_AI_BETA = "vertex_ai_beta"
-GEMINI = "gemini"
-AI21 = "ai21"
-BASETEN = "baseten"
-AZURE = "azure"
-AZURE_TEXT = "azure_text"
-AZURE_AI = "azure_ai"
-SAGEMAKER = "sagemaker"
-SAGEMAKER_CHAT = "sagemaker_chat"
-BEDROCK = "bedrock"
-VLLM = "vllm"
-NLP_CLOUD = "nlp_cloud"
-PETALS = "petals"
-OOBABOOGA = "oobabooga"
-OLLAMA = "ollama"
-OLLAMA_CHAT = "ollama_chat"
-DEEPINFRA = "deepinfra"
-PERPLEXITY = "perplexity"
-MISTRAL = "mistral"
-GROQ = "groq"
-NVIDIA_NIM = "nvidia_nim"
-CEREBRAS = "cerebras"
-AI21_CHAT = "ai21_chat"
-VOLCENGINE = "volcengine"
-CODESTRAL = "codestral"
-TEXT_COMPLETION_CODESTRAL = "text-completion-codestral"
-DEEPSEEK = "deepseek"
-SAMBANOVA = "sambanova"
-MARITALK = "maritalk"
-VOYAGE = "voyage"
-CLOUDFLARE = "cloudflare"
-XINFERENCE = "xinference"
-FIREWORKS_AI = "fireworks_ai"
-FRIENDLIAI = "friendliai"
-WATSONX = "watsonx"
-WATSONX_TEXT = "watsonx_text"
-TRITON = "triton"
-PREDIBASE = "predibase"
-DATABRICKS = "databricks"
-EMPOWER = "empower"
-GITHUB = "github"
-CUSTOM = "custom"
-LITELLM_PROXY = "litellm_proxy"
-HOSTED_VLLM = "hosted_vllm"
-LM_STUDIO = "lm_studio"
-GALADRIEL = "galadriel"
 provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
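With the enum now living in litellm.types.utils (and only re-exported from the package root), downstream code can reference providers without importing the heavy top-level module, e.g.:

```python
# The enum is defined in the leaf types module after this change;
# importing it from there avoids pulling in the package root.
from litellm.types.utils import LlmProviders

provider = LlmProviders.AZURE
print(provider.value)  # "azure"
```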

@@ -18,7 +18,7 @@ from litellm.types.llms.anthropic import (
 AnthropicResponse,
 ContentBlockDelta,
 )
-from litellm.types.utils import AdapterCompletionStreamWrapper
+from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse
 class AnthropicAdapter(CustomLogger):
@@ -41,7 +41,7 @@ class AnthropicAdapter(CustomLogger):
 return translated_body
 def translate_completion_output_params(
-self, response: litellm.ModelResponse
+self, response: ModelResponse
 ) -> Optional[AnthropicResponse]:
 return litellm.AnthropicExperimentalPassThroughConfig().translate_openai_response_to_anthropic(

@@ -484,7 +484,7 @@ def completion_cost( # noqa: PLR0915
 completion_characters: Optional[int] = None
 cache_creation_input_tokens: Optional[int] = None
 cache_read_input_tokens: Optional[int] = None
-cost_per_token_usage_object: Optional[litellm.Usage] = _get_usage_object(
+cost_per_token_usage_object: Optional[Usage] = _get_usage_object(
 completion_response=completion_response
 )
 if completion_response is not None and (
@@ -492,7 +492,7 @@ def completion_cost( # noqa: PLR0915
 or isinstance(completion_response, dict)
 ): # tts returns a custom class
-usage_obj: Optional[Union[dict, litellm.Usage]] = completion_response.get( # type: ignore
+usage_obj: Optional[Union[dict, Usage]] = completion_response.get( # type: ignore
 "usage", {}
 )
 if isinstance(usage_obj, BaseModel) and not isinstance(

@@ -39,6 +39,7 @@ from litellm.proxy._types import (
 VirtualKeyEvent,
 WebhookEvent,
 )
+from litellm.router import Router
 from litellm.types.integrations.slack_alerting import *
 from litellm.types.router import LiteLLM_Params
@@ -93,7 +94,7 @@ class SlackAlerting(CustomBatchLogger):
 alert_types: Optional[List[AlertType]] = None,
 alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]] = None,
 alerting_args: Optional[Dict] = None,
-llm_router: Optional[litellm.Router] = None,
+llm_router: Optional[Router] = None,
 ):
 if alerting is not None:
 self.alerting = alerting

@@ -18,6 +18,7 @@ from pydantic import BaseModel
 import litellm
 from litellm import (
+_custom_logger_compatible_callbacks_literal,
 json_logs,
 log_raw_request_response,
 turn_off_message_logging,
@@ -41,6 +42,7 @@ from litellm.types.utils import (
 CallTypes,
 EmbeddingResponse,
 ImageResponse,
+LiteLLMLoggingBaseClass,
 ModelResponse,
 StandardCallbackDynamicParams,
 StandardLoggingAdditionalHeaders,
@@ -190,7 +192,7 @@ in_memory_trace_id_cache = ServiceTraceIDCache()
 in_memory_dynamic_logger_cache = DynamicLoggingCache()
-class Logging:
+class Logging(LiteLLMLoggingBaseClass):
 global supabaseClient, promptLayerLogger, weightsBiasesLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger, logfireLogger, prometheusLogger, slack_app
 custom_pricing: bool = False
 stream_options = None
@@ -2142,7 +2144,7 @@ def set_callbacks(callback_list, function_id=None): # noqa: PLR0915
 def _init_custom_logger_compatible_class( # noqa: PLR0915
-logging_integration: litellm._custom_logger_compatible_callbacks_literal,
+logging_integration: _custom_logger_compatible_callbacks_literal,
 internal_usage_cache: Optional[DualCache],
 llm_router: Optional[
 Any
@@ -2362,7 +2364,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
 def get_custom_logger_compatible_class( # noqa: PLR0915
-logging_integration: litellm._custom_logger_compatible_callbacks_literal,
+logging_integration: _custom_logger_compatible_callbacks_literal,
 ) -> Optional[CustomLogger]:
 if logging_integration == "lago":
 for callback in _in_memory_loggers:
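The Logging class now subclasses the new LiteLLMLoggingBaseClass from litellm.types.utils, so provider modules can annotate a logging object without importing litellm_logging itself. A minimal sketch of the idea (the method names and signatures here are illustrative, not the actual interface):

```python
from typing import Any, Optional


# Sketch: a lightweight base class living in a leaf module (litellm/types/utils.py).
class LiteLLMLoggingBaseClass:
    """Typing-only surface; the concrete behavior stays in litellm_logging.Logging."""

    def pre_call(self, input: Any, api_key: Optional[str], additional_args: Optional[dict] = None) -> None:
        ...

    def post_call(
        self, input: Any, api_key: Optional[str], original_response: Any = None, additional_args: Optional[dict] = None
    ) -> None:
        ...


# Sketch: a provider handler can type-check against the base class with no
# import of litellm_core_utils.litellm_logging, so no cycle is introduced.
def process_response(raw_response: dict, logging_obj: LiteLLMLoggingBaseClass) -> dict:
    logging_obj.post_call(input=None, api_key=None, original_response=raw_response)
    return raw_response
```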

@@ -13,7 +13,6 @@ from jinja2.sandbox import ImmutableSandboxedEnvironment
 import litellm
 import litellm.types
 import litellm.types.llms
-import litellm.types.llms.vertex_ai
 from litellm import verbose_logger
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from litellm.types.completion import (
@@ -40,6 +39,9 @@ from litellm.types.llms.openai import (
 ChatCompletionUserMessage,
 OpenAIMessageContentListBlock,
 )
+from litellm.types.llms.vertex_ai import FunctionCall as VertexFunctionCall
+from litellm.types.llms.vertex_ai import FunctionResponse as VertexFunctionResponse
+from litellm.types.llms.vertex_ai import PartType as VertexPartType
 from litellm.types.utils import GenericImageParsingChunk
 from .common_utils import convert_content_list_to_str, is_non_content_values_set
@@ -965,11 +967,11 @@ def infer_protocol_value(
 def _gemini_tool_call_invoke_helper(
 function_call_params: ChatCompletionToolCallFunctionChunk,
-) -> Optional[litellm.types.llms.vertex_ai.FunctionCall]:
+) -> Optional[VertexFunctionCall]:
 name = function_call_params.get("name", "") or ""
 arguments = function_call_params.get("arguments", "")
 arguments_dict = json.loads(arguments)
-function_call = litellm.types.llms.vertex_ai.FunctionCall(
+function_call = VertexFunctionCall(
 name=name,
 args=arguments_dict,
 )
@@ -978,7 +980,7 @@ def _gemini_tool_call_invoke_helper(
 def convert_to_gemini_tool_call_invoke(
 message: ChatCompletionAssistantMessage,
-) -> List[litellm.types.llms.vertex_ai.PartType]:
+) -> List[VertexPartType]:
 """
 OpenAI tool invokes:
 {
@@ -1019,22 +1021,20 @@ convert_to_gemini_tool_call_invoke(
 - json.load the arguments
 """
 try:
-_parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
+_parts_list: List[VertexPartType] = []
 tool_calls = message.get("tool_calls", None)
 function_call = message.get("function_call", None)
 if tool_calls is not None:
 for tool in tool_calls:
 if "function" in tool:
-gemini_function_call: Optional[
-litellm.types.llms.vertex_ai.FunctionCall
-] = _gemini_tool_call_invoke_helper(
-function_call_params=tool["function"]
-)
+gemini_function_call: Optional[VertexFunctionCall] = (
+_gemini_tool_call_invoke_helper(
+function_call_params=tool["function"]
+)
+)
 if gemini_function_call is not None:
 _parts_list.append(
-litellm.types.llms.vertex_ai.PartType(
-function_call=gemini_function_call
-)
+VertexPartType(function_call=gemini_function_call)
 )
 else: # don't silently drop params. Make it clear to user what's happening.
 raise Exception(
@@ -1047,11 +1047,7 @@ convert_to_gemini_tool_call_invoke(
 function_call_params=function_call
 )
 if gemini_function_call is not None:
-_parts_list.append(
-litellm.types.llms.vertex_ai.PartType(
-function_call=gemini_function_call
-)
-)
+_parts_list.append(VertexPartType(function_call=gemini_function_call))
 else: # don't silently drop params. Make it clear to user what's happening.
 raise Exception(
 "function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
@@ -1070,7 +1066,7 @@ convert_to_gemini_tool_call_invoke(
 def convert_to_gemini_tool_call_result(
 message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
 last_message_with_tool_calls: Optional[dict],
-) -> litellm.types.llms.vertex_ai.PartType:
+) -> VertexPartType:
 """
 OpenAI message with a tool result looks like:
 {
@@ -1119,11 +1115,11 @@ def convert_to_gemini_tool_call_result(
 # We can't determine from openai message format whether it's a successful or
 # error call result so default to the successful result template
-_function_response = litellm.types.llms.vertex_ai.FunctionResponse(
+_function_response = VertexFunctionResponse(
 name=name, response={"content": content_str} # type: ignore
 )
-_part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
+_part = VertexPartType(function_response=_function_response)
 return _part

@@ -5,6 +5,12 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+anthropic_messages_pt,
+custom_prompt,
+prompt_factory,
+)
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
 from litellm.types.llms.anthropic import (
 AllAnthropicToolsValues,
 AnthopicMessagesAssistantMessageParam,
@@ -53,15 +59,9 @@ from litellm.types.llms.openai import (
 ChatCompletionUserMessage,
 OpenAIMessageContent,
 )
-from litellm.types.utils import Choices, GenericStreamingChunk
-from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+from litellm.types.utils import Choices, GenericStreamingChunk, ModelResponse, Usage
 from ...base import BaseLLM
-from litellm.litellm_core_utils.prompt_templates.factory import (
-anthropic_messages_pt,
-custom_prompt,
-prompt_factory,
-)
 class AnthropicExperimentalPassThroughConfig:
@@ -338,7 +338,7 @@ class AnthropicExperimentalPassThroughConfig:
 return "end_turn"
 def translate_openai_response_to_anthropic(
-self, response: litellm.ModelResponse
+self, response: ModelResponse
 ) -> AnthropicResponse:
 ## translate content block
 anthropic_content = self._translate_openai_content_to_anthropic(choices=response.choices) # type: ignore
@@ -347,7 +347,7 @@ class AnthropicExperimentalPassThroughConfig:
 openai_finish_reason=response.choices[0].finish_reason # type: ignore
 )
 # extract usage
-usage: litellm.Usage = getattr(response, "usage")
+usage: Usage = getattr(response, "usage")
 anthropic_usage = AnthropicResponseUsageBlock(
 input_tokens=usage.prompt_tokens or 0,
 output_tokens=usage.completion_tokens or 0,
@@ -393,7 +393,7 @@ class AnthropicExperimentalPassThroughConfig:
 return "text_delta", ContentTextBlockDelta(type="text_delta", text=text)
 def translate_streaming_openai_response_to_anthropic(
-self, response: litellm.ModelResponse
+self, response: ModelResponse
 ) -> Union[ContentBlockDelta, MessageBlockDelta]:
 ## base case - final chunk w/ finish reason
 if response.choices[0].finish_reason is not None:
@@ -403,7 +403,7 @@ class AnthropicExperimentalPassThroughConfig:
 ),
 )
 if getattr(response, "usage", None) is not None:
-litellm_usage_chunk: Optional[litellm.Usage] = response.usage # type: ignore
+litellm_usage_chunk: Optional[Usage] = response.usage # type: ignore
 elif (
 hasattr(response, "_hidden_params")
 and "usage" in response._hidden_params

@@ -17,10 +17,14 @@ from litellm.llms.custom_httpx.http_handler import (
 HTTPHandler,
 get_async_httpx_client,
 )
-from litellm.types.utils import EmbeddingResponse
+from litellm.types.utils import (
+EmbeddingResponse,
+ImageResponse,
+LlmProviders,
+ModelResponse,
+)
 from litellm.utils import (
 CustomStreamWrapper,
-ModelResponse,
 UnsupportedParamsError,
 convert_to_model_response_object,
 get_secret,
@@ -853,7 +857,7 @@ class AzureChatCompletion(BaseLLM):
 client=None,
 aembedding=None,
 headers: Optional[dict] = None,
-) -> litellm.EmbeddingResponse:
+) -> EmbeddingResponse:
 if headers:
 optional_params["extra_headers"] = headers
 if self._client_session is None:
@@ -963,7 +967,7 @@ class AzureChatCompletion(BaseLLM):
 _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
 async_handler = get_async_httpx_client(
-llm_provider=litellm.LlmProviders.AZURE,
+llm_provider=LlmProviders.AZURE,
 params=_params,
 )
 else:
@@ -1242,11 +1246,11 @@ class AzureChatCompletion(BaseLLM):
 api_key: Optional[str] = None,
 api_base: Optional[str] = None,
 api_version: Optional[str] = None,
-model_response: Optional[litellm.utils.ImageResponse] = None,
+model_response: Optional[ImageResponse] = None,
 azure_ad_token: Optional[str] = None,
 client=None,
 aimg_generation=None,
-) -> litellm.ImageResponse:
+) -> ImageResponse:
 try:
 if model and len(model) > 0:
 model = model
@@ -1510,7 +1514,7 @@ class AzureChatCompletion(BaseLLM):
 ) -> dict:
 client_session = (
 litellm.aclient_session
-or get_async_httpx_client(llm_provider=litellm.LlmProviders.AZURE).client
+or get_async_httpx_client(llm_provider=LlmProviders.AZURE).client
 ) # handle dall-e-2 calls
 if "gateway.ai.cloudflare.com" in api_base:

@@ -4,7 +4,11 @@ from typing import TYPE_CHECKING, Any, List, Optional, Type, Union
 from httpx._models import Headers, Response
 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+convert_to_azure_openai_messages,
+)
 from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.types.utils import ModelResponse
 from ....exceptions import UnsupportedParamsError
 from ....types.llms.openai import (
@@ -14,9 +18,7 @@ from ....types.llms.openai import (
 ChatCompletionToolParam,
 ChatCompletionToolParamFunctionChunk,
 )
 from ...base_llm.transformation import BaseConfig
-from litellm.litellm_core_utils.prompt_templates.factory import convert_to_azure_openai_messages
 from ..common_utils import AzureOpenAIError
 if TYPE_CHECKING:
@@ -26,6 +28,7 @@ if TYPE_CHECKING:
 else:
 LoggingClass = Any
 class AzureOpenAIConfig(BaseConfig):
 """
 Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
@@ -221,7 +224,7 @@ class AzureOpenAIConfig(BaseConfig):
 self,
 model: str,
 raw_response: Response,
-model_response: litellm.ModelResponse,
+model_response: ModelResponse,
 logging_obj: LoggingClass,
 request_data: dict,
 messages: List[AllMessageValues],
@@ -230,7 +233,7 @@ class AzureOpenAIConfig(BaseConfig):
 encoding: Any,
 api_key: Optional[str] = None,
 json_mode: Optional[bool] = None,
-) -> litellm.ModelResponse:
+) -> ModelResponse:
 raise NotImplementedError(
 "Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK."
 )

@@ -89,7 +89,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
 embedding_response = response.json()
 embedding_headers = dict(response.headers)
-returned_response: litellm.EmbeddingResponse = convert_to_model_response_object( # type: ignore
+returned_response: EmbeddingResponse = convert_to_model_response_object( # type: ignore
 response_object=embedding_response,
 model_response_object=model_response,
 response_type="embedding",
@@ -104,7 +104,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
 data: ImageEmbeddingRequest,
 timeout: float,
 logging_obj,
-model_response: litellm.EmbeddingResponse,
+model_response: EmbeddingResponse,
 optional_params: dict,
 api_key: Optional[str],
 api_base: Optional[str],
@@ -132,7 +132,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
 embedding_response = response.json()
 embedding_headers = dict(response.headers)
-returned_response: litellm.EmbeddingResponse = convert_to_model_response_object( # type: ignore
+returned_response: EmbeddingResponse = convert_to_model_response_object( # type: ignore
 response_object=embedding_response,
 model_response_object=model_response,
 response_type="embedding",
@@ -213,14 +213,14 @@ class AzureAIEmbedding(OpenAIChatCompletion):
 input: List,
 timeout: float,
 logging_obj,
-model_response: litellm.EmbeddingResponse,
+model_response: EmbeddingResponse,
 optional_params: dict,
 api_key: Optional[str] = None,
 api_base: Optional[str] = None,
 client=None,
 aembedding=None,
 max_retries: Optional[int] = None,
-) -> litellm.EmbeddingResponse:
+) -> EmbeddingResponse:
 """
 - Separate image url from text
 -> route image url call to `/image/embeddings`

@@ -5,6 +5,8 @@ import httpx
 import requests
 import litellm
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
+from litellm.types.utils import ModelResponse, TextCompletionResponse
 class BaseLLM:
@@ -15,7 +17,7 @@ class BaseLLM:
 self,
 model: str,
 response: Union[requests.Response, httpx.Response],
-model_response: litellm.utils.ModelResponse,
+model_response: ModelResponse,
 stream: bool,
 logging_obj: Any,
 optional_params: dict,
@@ -24,7 +26,7 @@ class BaseLLM:
 messages: list,
 print_verbose,
 encoding,
-) -> Union[litellm.utils.ModelResponse, litellm.utils.CustomStreamWrapper]:
+) -> Union[ModelResponse, CustomStreamWrapper]:
 """
 Helper function to process the response across sync + async completion calls
 """
@@ -34,7 +36,7 @@ class BaseLLM:
 self,
 model: str,
 response: Union[requests.Response, httpx.Response],
-model_response: litellm.utils.TextCompletionResponse,
+model_response: TextCompletionResponse,
 stream: bool,
 logging_obj: Any,
 optional_params: dict,
@@ -43,7 +45,7 @@ class BaseLLM:
 messages: list,
 print_verbose,
 encoding,
-) -> Union[litellm.utils.TextCompletionResponse, litellm.utils.CustomStreamWrapper]:
+) -> Union[TextCompletionResponse, CustomStreamWrapper]:
 """
 Helper function to process the response across sync + async completion calls
 """

@@ -32,6 +32,17 @@ from litellm import verbose_logger
 from litellm.caching.caching import InMemoryCache
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.litellm_core_utils.prompt_templates.factory import (
+_bedrock_converse_messages_pt,
+_bedrock_tools_pt,
+cohere_message_pt,
+construct_tool_use_system_prompt,
+contains_tag,
+custom_prompt,
+extract_between_tags,
+parse_xml_params,
+prompt_factory,
+)
 from litellm.llms.custom_httpx.http_handler import (
 AsyncHTTPHandler,
 HTTPHandler,
@@ -50,20 +61,10 @@ from litellm.types.llms.openai import (
 ChatCompletionUsageBlock,
 )
 from litellm.types.utils import GenericStreamingChunk as GChunk
-from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret
+from litellm.types.utils import ModelResponse, Usage
+from litellm.utils import CustomStreamWrapper, get_secret
 from ..base_aws_llm import BaseAWSLLM
-from litellm.litellm_core_utils.prompt_templates.factory import (
-_bedrock_converse_messages_pt,
-_bedrock_tools_pt,
-cohere_message_pt,
-construct_tool_use_system_prompt,
-contains_tag,
-custom_prompt,
-extract_between_tags,
-parse_xml_params,
-prompt_factory,
-)
 from ..common_utils import BedrockError, ModelResponseIterator, get_bedrock_tool_name
 from .converse_transformation import AmazonConverseConfig
@@ -1317,7 +1318,7 @@ class MockResponseIterator: # for returning ai21 streaming responses
 def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:
 try:
-chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
+chunk_usage: Usage = getattr(chunk_data, "usage")
 text = chunk_data.choices[0].message.content or "" # type: ignore
 tool_use = None
 if self.json_mode is True:

@@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
 import httpx
 import litellm
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+convert_content_list_to_str,
+)
 from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
 ChatCompletionToolCallChunk,
@@ -152,7 +154,7 @@ class ClarifaiConfig(BaseConfig):
 encoding: str,
 api_key: Optional[str] = None,
 json_mode: Optional[bool] = None,
-) -> litellm.ModelResponse:
+) -> ModelResponse:
 logging_obj.post_call(
 input=messages,
 api_key=api_key,

@@ -29,6 +29,7 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from litellm.types.llms.databricks import GenericStreamingChunk
+from litellm.types.utils import TextChoices
 from litellm.utils import (
 Choices,
 CustomStreamWrapper,
@@ -169,7 +170,7 @@ class CodestralTextCompletion(BaseLLM):
 raise TextCompletionCodestralError(message=response.text, status_code=422)
 _original_choices = completion_response.get("choices", [])
-_choices: List[litellm.utils.TextChoices] = []
+_choices: List[TextChoices] = []
 for choice in _original_choices:
 # This is what 1 choice looks like from codestral API
 # {

@@ -17,6 +17,7 @@ from litellm.llms.custom_httpx.http_handler import (
 get_async_httpx_client,
 )
 from litellm.types.llms.bedrock import CohereEmbeddingRequest
+from litellm.types.utils import EmbeddingResponse
 from litellm.utils import Choices, Message, ModelResponse, Usage
 from .transformation import CohereEmbeddingConfig
@@ -118,7 +119,7 @@ async def async_embedding(
 def embedding(
 model: str,
 input: list,
-model_response: litellm.EmbeddingResponse,
+model_response: EmbeddingResponse,
 logging_obj: LiteLLMLoggingObj,
 optional_params: dict,
 headers: dict,

@@ -21,7 +21,7 @@ class DatabricksEmbeddingHandler(OpenAILikeEmbeddingHandler, DatabricksBase):
 api_key: Optional[str],
 api_base: Optional[str],
 optional_params: dict,
-model_response: Optional[litellm.utils.EmbeddingResponse] = None,
+model_response: Optional[EmbeddingResponse] = None,
 client=None,
 aembedding=None,
 custom_endpoint: Optional[bool] = None,

@@ -55,9 +55,7 @@ class ModelResponseIterator:
 is_finished = True
 finish_reason = processed_chunk.choices[0].finish_reason
-usage_chunk: Optional[litellm.Usage] = getattr(
-processed_chunk, "usage", None
-)
+usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None)
 if usage_chunk is not None:
 usage = ChatCompletionUsageBlock(

@@ -24,6 +24,7 @@ import requests
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
 from litellm.llms.custom_httpx.http_handler import (
 AsyncHTTPHandler,
 HTTPHandler,
@@ -36,8 +37,9 @@ from litellm.llms.huggingface.chat.transformation import (
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.completion import ChatCompletionMessageToolCallParam
 from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import EmbeddingResponse
 from litellm.types.utils import Logprobs as TextCompletionLogprobs
-from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
+from litellm.types.utils import ModelResponse, Usage
 from ...base import BaseLLM
 from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks
@@ -453,11 +455,11 @@ class Huggingface(BaseLLM):
 def _process_embedding_response(
 self,
 embeddings: dict,
-model_response: litellm.EmbeddingResponse,
+model_response: EmbeddingResponse,
 model: str,
 input: List,
 encoding: Any,
-) -> litellm.EmbeddingResponse:
+) -> EmbeddingResponse:
 output_data = []
 if "similarities" in embeddings:
 for idx, embedding in embeddings["similarities"]:
@@ -583,7 +585,7 @@ class Huggingface(BaseLLM):
 self,
 model: str,
 input: list,
-model_response: litellm.EmbeddingResponse,
+model_response: EmbeddingResponse,
 optional_params: dict,
 logging_obj: LiteLLMLoggingObj,
 encoding: Callable,
@@ -593,7 +595,7 @@ class Huggingface(BaseLLM):
 aembedding: Optional[bool] = None,
 client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 headers={},
-) -> litellm.EmbeddingResponse:
+) -> EmbeddingResponse:
 super().embedding()
 headers = hf_chat_config.validate_environment(
 api_key=api_key,

@@ -8,10 +8,15 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 import httpx
 import litellm
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+convert_content_list_to_str,
+)
+from litellm.litellm_core_utils.prompt_templates.factory import (
+custom_prompt,
+prompt_factory,
+)
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, Message, ModelResponse, Usage
@@ -407,7 +412,7 @@ class HuggingfaceChatConfig(BaseConfig):
 def convert_to_model_response_object( # noqa: PLR0915
 self,
 completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
-model_response: litellm.ModelResponse,
+model_response: ModelResponse,
 task: Optional[hf_tasks],
 optional_params: dict,
 encoding: Any,

@@ -14,11 +14,20 @@ import requests # type: ignore
 import litellm
 from litellm import verbose_logger
+from litellm.litellm_core_utils.prompt_templates.factory import (
+custom_prompt,
+prompt_factory,
+)
 from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
+from litellm.types.utils import (
+EmbeddingResponse,
+ModelInfo,
+ModelResponse,
+ProviderField,
+StreamingChoices,
+)
 from ..common_utils import OllamaError
 from .transformation import OllamaConfig
@@ -53,7 +62,7 @@ def _convert_image(image):
 # ollama implementation
 def get_ollama_response(
-model_response: litellm.ModelResponse,
+model_response: ModelResponse,
 model: str,
 prompt: str,
 optional_params: dict,
@@ -391,7 +400,7 @@ async def ollama_aembeddings(
 api_base: str,
 model: str,
 prompts: List[str],
-model_response: litellm.EmbeddingResponse,
+model_response: EmbeddingResponse,
 optional_params: dict,
 logging_obj: Any,
 encoding: Any,
@@ -479,7 +488,7 @@ def ollama_embeddings(
 model: str,
 prompts: list,
 optional_params: dict,
-model_response: litellm.EmbeddingResponse,
+model_response: EmbeddingResponse,
 logging_obj: Any,
 encoding=None,
 ):

@@ -17,7 +17,7 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
-from litellm.types.utils import StreamingChoices
+from litellm.types.utils import ModelResponse, StreamingChoices
 class OllamaError(Exception):
@@ -198,7 +198,7 @@ class OllamaChatConfig(OpenAIGPTConfig):
 # ollama implementation
 def get_ollama_response( # noqa: PLR0915
-model_response: litellm.ModelResponse,
+model_response: ModelResponse,
 messages: list,
 optional_params: dict,
 model: str,

@@ -28,24 +28,31 @@ import litellm
 from litellm import LlmProviders
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.prompt_templates.factory import (
+custom_prompt,
+prompt_factory,
+)
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.utils import ProviderField
+from litellm.types.utils import (
+EmbeddingResponse,
+ImageResponse,
+ModelResponse,
+ProviderField,
+TextCompletionResponse,
+Usage,
+)
 from litellm.utils import (
 Choices,
 CustomStreamWrapper,
 Message,
-ModelResponse,
 ProviderConfigManager,
-TextCompletionResponse,
-Usage,
 convert_to_model_response_object,
 )
 from ...types.llms.openai import *
 from ..base import BaseLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from .chat.gpt_transformation import OpenAIGPTConfig
 from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
@@ -882,7 +889,7 @@ class OpenAIChatCompletion(BaseLLM):
 self,
 input: list,
 data: dict,
-model_response: litellm.utils.EmbeddingResponse,
+model_response: EmbeddingResponse,
 timeout: float,
 logging_obj: LiteLLMLoggingObj,
 api_key: Optional[str] = None,
@@ -911,9 +918,7 @@ class OpenAIChatCompletion(BaseLLM):
 additional_args={"complete_input_dict": data},
 original_response=stringified_response,
 )
-returned_response: (
-litellm.EmbeddingResponse
-) = convert_to_model_response_object(
+returned_response: EmbeddingResponse = convert_to_model_response_object(
 response_object=stringified_response,
 model_response_object=model_response,
 response_type="embedding",
@@ -953,14 +958,14 @@ class OpenAIChatCompletion(BaseLLM):
 input: list,
 timeout: float,
 logging_obj,
-model_response: litellm.utils.EmbeddingResponse,
+model_response: EmbeddingResponse,
 optional_params: dict,
 api_key: Optional[str] = None,
 api_base: Optional[str] = None,
 client=None,
 aembedding=None,
 max_retries: Optional[int] = None,
-) -> litellm.EmbeddingResponse:
+) -> EmbeddingResponse:
 super().embedding()
 try:
 model = model
@@ -1011,7 +1016,7 @@ class OpenAIChatCompletion(BaseLLM):
 additional_args={"complete_input_dict": data},
 original_response=sync_embedding_response,
 )
-response: litellm.EmbeddingResponse = convert_to_model_response_object(
+response: EmbeddingResponse = convert_to_model_response_object(
 response_object=sync_embedding_response.model_dump(),
 model_response_object=model_response,
 _response_headers=headers,
@@ -1068,7 +1073,7 @@ class OpenAIChatCompletion(BaseLLM):
 except Exception as e:
 ## LOGGING
 logging_obj.post_call(
-input=input,
+input=prompt,
 api_key=api_key,
 original_response=str(e),
 )
@@ -1083,10 +1088,10 @@ class OpenAIChatCompletion(BaseLLM):
 logging_obj: Any,
 api_key: Optional[str] = None,
 api_base: Optional[str] = None,
-model_response: Optional[litellm.utils.ImageResponse] = None,
+model_response: Optional[ImageResponse] = None,
 client=None,
 aimg_generation=None,
-) -> litellm.ImageResponse:
+) -> ImageResponse:
 data = {}
 try:
 model = model

@@ -3,7 +3,7 @@ OpenAI-like chat completion transformation
 """
 import types
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 import httpx
 from pydantic import BaseModel
@@ -16,6 +16,13 @@ from litellm.types.utils import ModelResponse
 from ....utils import _remove_additional_properties, _remove_strict_from_schema
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+if TYPE_CHECKING:
+from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+LiteLLMLoggingObj = Any
 class OpenAILikeChatConfig(OpenAIGPTConfig):
 def _get_openai_compatible_provider_info(
@@ -64,7 +71,7 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
 response: httpx.Response,
 model_response: ModelResponse,
 stream: bool,
-logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore
+logging_obj: LiteLLMLoggingObj,
 optional_params: dict,
 api_key: Optional[str],
 data: Union[dict, str],
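The TYPE_CHECKING guard added above is the other recurring pattern in this commit: the concrete Logging class is imported only for static type checkers, while at runtime the annotation degrades to Any, so importing this module never triggers an import of litellm_logging and the cycle is broken. In isolation (the transform_response helper below is purely illustrative):

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime.
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    # At runtime the name is just Any, so no circular import can occur.
    LiteLLMLoggingObj = Any


def transform_response(raw: dict, logging_obj: LiteLLMLoggingObj) -> dict:
    # Illustrative helper: fully typed for mypy, opaque at runtime.
    logging_obj.post_call(input=None, api_key=None, original_response=raw)
    return raw
```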

@@ -11,8 +11,7 @@ from enum import Enum
 from functools import partial
 from typing import Any, Callable, List, Literal, Optional, Tuple, Union
-import httpx # type: ignore
-import requests # type: ignore
+import httpx
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
@@ -21,7 +20,7 @@ from litellm.llms.custom_httpx.http_handler import (
 HTTPHandler,
 get_async_httpx_client,
 )
-from litellm.utils import EmbeddingResponse
+from litellm.types.utils import EmbeddingResponse
 from ..common_utils import OpenAILikeBase, OpenAILikeError
@@ -100,7 +99,7 @@ class OpenAILikeEmbeddingHandler(OpenAILikeBase):
 api_key: Optional[str],
 api_base: Optional[str],
 optional_params: dict,
-model_response: Optional[litellm.utils.EmbeddingResponse] = None,
+model_response: Optional[EmbeddingResponse] = None,
 client=None,
 aembedding=None,
 custom_endpoint: Optional[bool] = None,

@@ -10,6 +10,7 @@ from litellm.llms.base_llm.transformation import (
 LiteLLMLoggingObj,
 )
 from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
 from ..common_utils import PetalsError
@@ -111,7 +112,7 @@ class PetalsConfig(BaseConfig):
 self,
 model: str,
 raw_response: Response,
-model_response: litellm.ModelResponse,
+model_response: ModelResponse,
 logging_obj: LiteLLMLoggingObj,
 request_data: dict,
 messages: List[AllMessageValues],
@@ -120,7 +121,7 @@ class PetalsConfig(BaseConfig):
 encoding: Any,
 api_key: Optional[str] = None,
 json_mode: Optional[bool] = None,
-) -> litellm.ModelResponse:
+) -> ModelResponse:
 raise NotImplementedError(
 "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
 )

@@ -27,6 +27,7 @@ from litellm.llms.custom_httpx.http_handler import (
 AsyncHTTPHandler,
 get_async_httpx_client,
 )
+from litellm.types.utils import LiteLLMLoggingBaseClass
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
 from ...base import BaseLLM
@@ -92,7 +93,7 @@ class PredibaseChatCompletion(BaseLLM):
 response: Union[requests.Response, httpx.Response],
 model_response: ModelResponse,
 stream: bool,
-logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
+logging_obj: LiteLLMLoggingBaseClass,
 optional_params: dict,
 api_key: str,
 data: Union[dict, str],

@@ -13,10 +13,13 @@ from httpx._models import Headers, Response
 import litellm
 from litellm.litellm_core_utils.asyncify import asyncify
+from litellm.litellm_core_utils.prompt_templates.factory import (
+custom_prompt,
+prompt_factory,
+)
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import Usage
+from litellm.types.utils import ModelResponse, Usage
 from ..common_utils import SagemakerError
@@ -197,7 +200,7 @@ class SagemakerConfig(BaseConfig):
 self,
 model: str,
 raw_response: Response,
-model_response: litellm.ModelResponse,
+model_response: ModelResponse,
 logging_obj: LiteLLMLoggingObj,
 request_data: dict,
 messages: List[AllMessageValues],
@@ -206,7 +209,7 @@ class SagemakerConfig(BaseConfig):
 encoding: str,
 api_key: Optional[str] = None,
 json_mode: Optional[bool] = None,
-) -> litellm.ModelResponse:
+) -> ModelResponse:
 completion_response = raw_response.json()
 ## LOGGING
 logging_obj.post_call(

View file

@ -5,20 +5,20 @@ Why separate file? Make it easy to see how transformation works
""" """
import os import os
from typing import List, Literal, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union, cast
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
import litellm import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.litellm_core_utils.prompt_templates.factory import ( from litellm.litellm_core_utils.prompt_templates.factory import (
convert_to_anthropic_image_obj, convert_to_anthropic_image_obj,
convert_to_gemini_tool_call_invoke, convert_to_gemini_tool_call_invoke,
convert_to_gemini_tool_call_result, convert_to_gemini_tool_call_result,
response_schema_prompt, response_schema_prompt,
) )
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.files import ( from litellm.types.files import (
get_file_mime_type_for_file_type, get_file_mime_type_for_file_type,
get_file_type_from_extension, get_file_type_from_extension,
@ -49,6 +49,13 @@ from ..common_utils import (
get_supports_system_message, get_supports_system_message,
) )
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
def _process_gemini_image(image_url: str) -> PartType: def _process_gemini_image(image_url: str) -> PartType:
""" """
@ -348,7 +355,7 @@ def sync_transform_request_body(
timeout: Optional[Union[float, httpx.Timeout]], timeout: Optional[Union[float, httpx.Timeout]],
extra_headers: Optional[dict], extra_headers: Optional[dict],
optional_params: dict, optional_params: dict,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore logging_obj: LiteLLMLoggingObj,
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
litellm_params: dict, litellm_params: dict,
) -> RequestBody: ) -> RequestBody:
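The TYPE_CHECKING guard added in this hunk is the pattern the commit leans on whenever a heavy class is needed only for annotations. A minimal standalone sketch of the same idea, reusing the module path from the hunk; the function name below is purely illustrative:

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Evaluated only by static type checkers, so no runtime import cycle is created.
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    # At runtime the alias degrades to Any and the annotation still resolves.
    LiteLLMLoggingObj = Any


def transform_request(optional_params: dict, logging_obj: LiteLLMLoggingObj) -> dict:
    # The body never touches the concrete Logging class; only the annotation does.
    return optional_params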
View file
@ -10,18 +10,17 @@ from litellm.llms.custom_httpx.http_handler import (
HTTPHandler, HTTPHandler,
get_async_httpx_client, get_async_httpx_client,
) )
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
VertexLLM, from litellm.types.utils import ImageResponse
)
class VertexImageGeneration(VertexLLM): class VertexImageGeneration(VertexLLM):
def process_image_generation_response( def process_image_generation_response(
self, self,
json_response: Dict[str, Any], json_response: Dict[str, Any],
model_response: litellm.ImageResponse, model_response: ImageResponse,
model: Optional[str] = None, model: Optional[str] = None,
) -> litellm.ImageResponse: ) -> ImageResponse:
if "predictions" not in json_response: if "predictions" not in json_response:
raise litellm.InternalServerError( raise litellm.InternalServerError(
message=f"image generation response does not contain 'predictions', got {json_response}", message=f"image generation response does not contain 'predictions', got {json_response}",
@ -46,7 +45,7 @@ class VertexImageGeneration(VertexLLM):
vertex_project: Optional[str], vertex_project: Optional[str],
vertex_location: Optional[str], vertex_location: Optional[str],
vertex_credentials: Optional[str], vertex_credentials: Optional[str],
model_response: litellm.ImageResponse, model_response: ImageResponse,
logging_obj: Any, logging_obj: Any,
model: Optional[ model: Optional[
str str
@ -55,7 +54,7 @@ class VertexImageGeneration(VertexLLM):
optional_params: Optional[dict] = None, optional_params: Optional[dict] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
aimg_generation=False, aimg_generation=False,
) -> litellm.ImageResponse: ) -> ImageResponse:
if aimg_generation is True: if aimg_generation is True:
return self.aimage_generation( # type: ignore return self.aimage_generation( # type: ignore
prompt=prompt, prompt=prompt,
View file
@ -22,7 +22,7 @@ from litellm.types.llms.vertex_ai import (
MultimodalPredictions, MultimodalPredictions,
VertexMultimodalEmbeddingRequest, VertexMultimodalEmbeddingRequest,
) )
from litellm.types.utils import Embedding from litellm.types.utils import Embedding, EmbeddingResponse
from litellm.utils import is_base64_encoded from litellm.utils import is_base64_encoded
@ -39,7 +39,7 @@ class VertexMultimodalEmbedding(VertexLLM):
model: str, model: str,
input: Union[list, str], input: Union[list, str],
print_verbose, print_verbose,
model_response: litellm.EmbeddingResponse, model_response: EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"], custom_llm_provider: Literal["gemini", "vertex_ai"],
optional_params: dict, optional_params: dict,
logging_obj: LiteLLMLoggingObj, logging_obj: LiteLLMLoggingObj,
@ -52,7 +52,7 @@ class VertexMultimodalEmbedding(VertexLLM):
aembedding=False, aembedding=False,
timeout=300, timeout=300,
client=None, client=None,
) -> litellm.EmbeddingResponse: ) -> EmbeddingResponse:
_auth_header, vertex_project = self._ensure_access_token( _auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, credentials=vertex_credentials,
View file
@ -15,12 +15,10 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client, _get_httpx_client,
get_async_httpx_client, get_async_httpx_client,
) )
from litellm.llms.vertex_ai.vertex_ai_non_gemini import ( from litellm.llms.vertex_ai.vertex_ai_non_gemini import VertexAIError
VertexAIError,
)
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.llms.vertex_ai import * from litellm.types.llms.vertex_ai import *
from litellm.utils import Usage from litellm.types.utils import EmbeddingResponse, Usage
from .transformation import VertexAITextEmbeddingConfig from .transformation import VertexAITextEmbeddingConfig
from .types import * from .types import *
@ -35,7 +33,7 @@ class VertexEmbedding(VertexBase):
model: str, model: str,
input: Union[list, str], input: Union[list, str],
print_verbose, print_verbose,
model_response: litellm.EmbeddingResponse, model_response: EmbeddingResponse,
optional_params: dict, optional_params: dict,
logging_obj: LiteLLMLoggingObject, logging_obj: LiteLLMLoggingObject,
custom_llm_provider: Literal[ custom_llm_provider: Literal[
@ -52,7 +50,7 @@ class VertexEmbedding(VertexBase):
vertex_credentials: Optional[str] = None, vertex_credentials: Optional[str] = None,
gemini_api_key: Optional[str] = None, gemini_api_key: Optional[str] = None,
extra_headers: Optional[dict] = None, extra_headers: Optional[dict] = None,
) -> litellm.EmbeddingResponse: ) -> EmbeddingResponse:
if aembedding is True: if aembedding is True:
return self.async_embedding( # type: ignore return self.async_embedding( # type: ignore
model=model, model=model,
View file
@ -4,7 +4,7 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
import litellm import litellm
from litellm.utils import Usage from litellm.types.utils import EmbeddingResponse, Usage
from .types import * from .types import *
@ -198,8 +198,8 @@ class VertexAITextEmbeddingConfig(BaseModel):
return text_embedding_input return text_embedding_input
def transform_vertex_response_to_openai( def transform_vertex_response_to_openai(
self, response: dict, model: str, model_response: litellm.EmbeddingResponse self, response: dict, model: str, model_response: EmbeddingResponse
) -> litellm.EmbeddingResponse: ) -> EmbeddingResponse:
""" """
Transforms a vertex embedding response to an openai response. Transforms a vertex embedding response to an openai response.
""" """
@ -234,8 +234,8 @@ class VertexAITextEmbeddingConfig(BaseModel):
return model_response return model_response
def _transform_vertex_response_to_openai_for_fine_tuned_models( def _transform_vertex_response_to_openai_for_fine_tuned_models(
self, response: dict, model: str, model_response: litellm.EmbeddingResponse self, response: dict, model: str, model_response: EmbeddingResponse
) -> litellm.EmbeddingResponse: ) -> EmbeddingResponse:
""" """
Transforms a vertex fine-tuned model embedding response to an openai response format. Transforms a vertex fine-tuned model embedding response to an openai response format.
""" """
View file
@ -24,6 +24,8 @@ import httpx # type: ignore
import requests # type: ignore import requests # type: ignore
import litellm import litellm
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler, AsyncHTTPHandler,
get_async_httpx_client, get_async_httpx_client,
@ -34,7 +36,6 @@ from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
from ...base import BaseLLM from ...base import BaseLLM
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
from .transformation import IBMWatsonXAIConfig from .transformation import IBMWatsonXAIConfig
@ -204,7 +205,7 @@ class IBMWatsonXAI(BaseLLM):
def process_stream_response( def process_stream_response(
stream_resp: Union[Iterator[str], AsyncIterator], stream_resp: Union[Iterator[str], AsyncIterator],
) -> litellm.CustomStreamWrapper: ) -> CustomStreamWrapper:
streamwrapper = litellm.CustomStreamWrapper( streamwrapper = litellm.CustomStreamWrapper(
stream_resp, stream_resp,
model=model, model=model,
@ -235,7 +236,7 @@ class IBMWatsonXAI(BaseLLM):
json_resp = resp.json() json_resp = resp.json()
return self._process_text_gen_response(json_resp, model_response) return self._process_text_gen_response(json_resp, model_response)
def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper: def handle_stream_request(request_params: dict) -> CustomStreamWrapper:
# stream the response - generated chunks will be handled # stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
with self.request_manager.request( with self.request_manager.request(
@ -249,7 +250,7 @@ class IBMWatsonXAI(BaseLLM):
async def handle_stream_request_async( async def handle_stream_request_async(
request_params: dict, request_params: dict,
) -> litellm.CustomStreamWrapper: ) -> CustomStreamWrapper:
# stream the response - generated chunks will be handled # stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
async with self.request_manager.async_request( async with self.request_manager.async_request(
@ -321,14 +322,14 @@ class IBMWatsonXAI(BaseLLM):
self, self,
model: str, model: str,
input: Union[list, str], input: Union[list, str],
model_response: litellm.EmbeddingResponse, model_response: EmbeddingResponse,
api_key: Optional[str], api_key: Optional[str],
logging_obj: Any, logging_obj: Any,
optional_params: dict, optional_params: dict,
encoding=None, encoding=None,
print_verbose=None, print_verbose=None,
aembedding=None, aembedding=None,
) -> litellm.EmbeddingResponse: ) -> EmbeddingResponse:
""" """
Send a text embedding request to the IBM Watsonx.ai API. Send a text embedding request to the IBM Watsonx.ai API.
""" """
View file
@ -35,6 +35,7 @@ from litellm.proxy._types import (
) )
from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.router import Router
from litellm.types.services import ServiceLoggerPayload, ServiceTypes from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from .auth_checks_organization import organization_role_based_access_check from .auth_checks_organization import organization_role_based_access_check
@ -61,7 +62,7 @@ def common_checks( # noqa: PLR0915
global_proxy_spend: Optional[float], global_proxy_spend: Optional[float],
general_settings: dict, general_settings: dict,
route: str, route: str,
llm_router: Optional[litellm.Router], llm_router: Optional[Router],
) -> bool: ) -> bool:
""" """
Common checks across jwt + key-based auth. Common checks across jwt + key-based auth.
@ -347,7 +348,7 @@ async def get_end_user_object(
def model_in_access_group( def model_in_access_group(
model: str, team_models: Optional[List[str]], llm_router: Optional[litellm.Router] model: str, team_models: Optional[List[str]], llm_router: Optional[Router]
) -> bool: ) -> bool:
from collections import defaultdict from collections import defaultdict
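Several proxy modules (auth_checks, prompt_injection_detection, proxy/utils, proxy_server) apply the same fix for router annotations: import the class from litellm.router instead of going through the lazily populated litellm.Router attribute. A tiny sketch, with the function name invented for illustration:

from typing import Optional

from litellm.router import Router


def has_active_router(llm_router: Optional[Router]) -> bool:
    # Placeholder body; the point is that the annotation no longer pulls in the
    # whole `litellm` package just to name the Router type.
    return llm_router is not None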
View file
@ -4,10 +4,11 @@ from typing import Any, Optional
import litellm import litellm
from litellm import CustomLLM, ImageObject, ImageResponse, completion, get_llm_provider from litellm import CustomLLM, ImageObject, ImageResponse, completion, get_llm_provider
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.utils import ModelResponse
class MyCustomLLM(CustomLLM): class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse: def completion(self, *args, **kwargs) -> ModelResponse:
return litellm.completion( return litellm.completion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}], messages=[{"role": "user", "content": "Hello world"}],
View file
@ -45,6 +45,7 @@ from litellm.types.guardrails import (
BedrockTextContent, BedrockTextContent,
GuardrailEventHooks, GuardrailEventHooks,
) )
from litellm.types.utils import ModelResponse
GUARDRAIL_NAME = "bedrock" GUARDRAIL_NAME = "bedrock"
@ -70,7 +71,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
def convert_to_bedrock_format( def convert_to_bedrock_format(
self, self,
messages: Optional[List[Dict[str, str]]] = None, messages: Optional[List[Dict[str, str]]] = None,
response: Optional[Union[Any, litellm.ModelResponse]] = None, response: Optional[Union[Any, ModelResponse]] = None,
) -> BedrockRequest: ) -> BedrockRequest:
bedrock_request: BedrockRequest = BedrockRequest(source="INPUT") bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
bedrock_request_content: List[BedrockContentItem] = [] bedrock_request_content: List[BedrockContentItem] = []
View file
@ -20,8 +20,11 @@ import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.factory import prompt_injection_detection_default_pt from litellm.litellm_core_utils.prompt_templates.factory import (
prompt_injection_detection_default_pt,
)
from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
from litellm.router import Router
from litellm.utils import get_formatted_prompt from litellm.utils import get_formatted_prompt
@ -32,7 +35,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None, prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
): ):
self.prompt_injection_params = prompt_injection_params self.prompt_injection_params = prompt_injection_params
self.llm_router: Optional[litellm.Router] = None self.llm_router: Optional[Router] = None
self.verbs = [ self.verbs = [
"Ignore", "Ignore",
@ -74,7 +77,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
if litellm.set_verbose is True: if litellm.set_verbose is True:
print(print_statement) # noqa print(print_statement) # noqa
def update_environment(self, router: Optional[litellm.Router] = None): def update_environment(self, router: Optional[Router] = None):
self.llm_router = router self.llm_router = router
if ( if (
View file
@ -16,6 +16,7 @@ from litellm.llms.anthropic.chat.handler import (
from litellm.llms.anthropic.chat.transformation import AnthropicConfig from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
from litellm.types.utils import ModelResponse, TextCompletionResponse
if TYPE_CHECKING: if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging from ..success_handler import PassThroughEndpointLogging
@ -43,9 +44,7 @@ class AnthropicPassthroughLoggingHandler:
Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
""" """
model = response_body.get("model", "") model = response_body.get("model", "")
litellm_model_response: ( litellm_model_response: ModelResponse = AnthropicConfig().transform_response(
litellm.ModelResponse
) = AnthropicConfig().transform_response(
raw_response=httpx_response, raw_response=httpx_response,
model_response=litellm.ModelResponse(), model_response=litellm.ModelResponse(),
model=model, model=model,
@ -89,9 +88,7 @@ class AnthropicPassthroughLoggingHandler:
@staticmethod @staticmethod
def _create_anthropic_response_logging_payload( def _create_anthropic_response_logging_payload(
litellm_model_response: Union[ litellm_model_response: Union[ModelResponse, TextCompletionResponse],
litellm.ModelResponse, litellm.TextCompletionResponse
],
model: str, model: str,
kwargs: dict, kwargs: dict,
start_time: datetime, start_time: datetime,
@ -204,7 +201,7 @@ class AnthropicPassthroughLoggingHandler:
all_chunks: List[str], all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj, litellm_logging_obj: LiteLLMLoggingObj,
model: str, model: str,
) -> Optional[Union[litellm.ModelResponse, litellm.TextCompletionResponse]]: ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
""" """
Builds complete response from raw Anthropic chunks Builds complete response from raw Anthropic chunks
View file
@ -15,6 +15,12 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator as VertexModelResponseIterator, ModelResponseIterator as VertexModelResponseIterator,
) )
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
TextCompletionResponse,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging from ..success_handler import PassThroughEndpointLogging
@ -40,7 +46,7 @@ class VertexPassthroughLoggingHandler:
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route) model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
instance_of_vertex_llm = litellm.VertexGeminiConfig() instance_of_vertex_llm = litellm.VertexGeminiConfig()
litellm_model_response: litellm.ModelResponse = ( litellm_model_response: ModelResponse = (
instance_of_vertex_llm.transform_response( instance_of_vertex_llm.transform_response(
model=model, model=model,
messages=[ messages=[
@ -82,8 +88,8 @@ class VertexPassthroughLoggingHandler:
_json_response = httpx_response.json() _json_response = httpx_response.json()
litellm_prediction_response: Union[ litellm_prediction_response: Union[
litellm.ModelResponse, litellm.EmbeddingResponse, litellm.ImageResponse ModelResponse, EmbeddingResponse, ImageResponse
] = litellm.ModelResponse() ] = ModelResponse()
if vertex_image_generation_class.is_image_generation_response( if vertex_image_generation_class.is_image_generation_response(
_json_response _json_response
): ):
@ -176,7 +182,7 @@ class VertexPassthroughLoggingHandler:
all_chunks: List[str], all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj, litellm_logging_obj: LiteLLMLoggingObj,
model: str, model: str,
) -> Optional[Union[litellm.ModelResponse, litellm.TextCompletionResponse]]: ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
vertex_iterator = VertexModelResponseIterator( vertex_iterator = VertexModelResponseIterator(
streaming_response=None, streaming_response=None,
sync_stream=False, sync_stream=False,
@ -212,9 +218,7 @@ class VertexPassthroughLoggingHandler:
@staticmethod @staticmethod
def _create_vertex_response_logging_payload_for_generate_content( def _create_vertex_response_logging_payload_for_generate_content(
litellm_model_response: Union[ litellm_model_response: Union[ModelResponse, TextCompletionResponse],
litellm.ModelResponse, litellm.TextCompletionResponse
],
model: str, model: str,
kwargs: dict, kwargs: dict,
start_time: datetime, start_time: datetime,
View file
@ -109,6 +109,7 @@ from litellm import (
CreateBatchRequest, CreateBatchRequest,
ListBatchRequest, ListBatchRequest,
RetrieveBatchRequest, RetrieveBatchRequest,
Router,
) )
from litellm._logging import verbose_proxy_logger, verbose_router_logger from litellm._logging import verbose_proxy_logger, verbose_router_logger
from litellm.caching.caching import DualCache, RedisCache from litellm.caching.caching import DualCache, RedisCache
@ -482,7 +483,7 @@ user_config_file_path: Optional[str] = None
local_logging = True # writes logs to a local api_log.json file for debugging local_logging = True # writes logs to a local api_log.json file for debugging
experimental = False experimental = False
#### GLOBAL VARIABLES #### #### GLOBAL VARIABLES ####
llm_router: Optional[litellm.Router] = None llm_router: Optional[Router] = None
llm_model_list: Optional[list] = None llm_model_list: Optional[list] = None
general_settings: dict = {} general_settings: dict = {}
callback_settings: dict = {} callback_settings: dict = {}
@ -2833,7 +2834,7 @@ class ProxyStartupEvent:
@classmethod @classmethod
def _initialize_startup_logging( def _initialize_startup_logging(
cls, cls,
llm_router: Optional[litellm.Router], llm_router: Optional[Router],
proxy_logging_obj: ProxyLogging, proxy_logging_obj: ProxyLogging,
redis_usage_cache: Optional[RedisCache], redis_usage_cache: Optional[RedisCache],
): ):
View file
@ -289,7 +289,7 @@ class ProxyLogging:
def startup_event( def startup_event(
self, self,
llm_router: Optional[litellm.Router], llm_router: Optional[Router],
redis_usage_cache: Optional[RedisCache], redis_usage_cache: Optional[RedisCache],
): ):
"""Initialize logging and alerting on proxy startup""" """Initialize logging and alerting on proxy startup"""
@ -359,7 +359,7 @@ class ProxyLogging:
if redis_cache is not None: if redis_cache is not None:
self.internal_usage_cache.dual_cache.redis_cache = redis_cache self.internal_usage_cache.dual_cache.redis_cache = redis_cache
def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None): def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore
litellm.callbacks.append(self.max_budget_limiter) # type: ignore litellm.callbacks.append(self.max_budget_limiter) # type: ignore
litellm.callbacks.append(self.cache_control_check) # type: ignore litellm.callbacks.append(self.cache_control_check) # type: ignore
View file
@ -145,6 +145,7 @@ from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.utils import StandardLoggingPayload from litellm.types.utils import StandardLoggingPayload
from litellm.utils import ( from litellm.utils import (
CustomStreamWrapper, CustomStreamWrapper,
EmbeddingResponse,
ModelResponse, ModelResponse,
_is_region_eu, _is_region_eu,
calculate_max_parallel_requests, calculate_max_parallel_requests,
@ -2071,7 +2072,7 @@ class Router:
input: Union[str, List], input: Union[str, List],
is_async: Optional[bool] = False, is_async: Optional[bool] = False,
**kwargs, **kwargs,
) -> litellm.EmbeddingResponse: ) -> EmbeddingResponse:
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["input"] = input kwargs["input"] = input
@ -2146,7 +2147,7 @@ class Router:
input: Union[str, List], input: Union[str, List],
is_async: Optional[bool] = True, is_async: Optional[bool] = True,
**kwargs, **kwargs,
) -> litellm.EmbeddingResponse: ) -> EmbeddingResponse:
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["input"] = input kwargs["input"] = input
View file
@ -1660,3 +1660,84 @@ class PersonalUIKeyGenerationConfig(KeyGenerationConfig):
class StandardKeyGenerationConfig(TypedDict, total=False): class StandardKeyGenerationConfig(TypedDict, total=False):
team_key_generation: TeamUIKeyGenerationConfig team_key_generation: TeamUIKeyGenerationConfig
personal_key_generation: PersonalUIKeyGenerationConfig personal_key_generation: PersonalUIKeyGenerationConfig
class LlmProviders(str, Enum):
OPENAI = "openai"
OPENAI_LIKE = "openai_like" # embedding only
JINA_AI = "jina_ai"
XAI = "xai"
CUSTOM_OPENAI = "custom_openai"
TEXT_COMPLETION_OPENAI = "text-completion-openai"
COHERE = "cohere"
COHERE_CHAT = "cohere_chat"
CLARIFAI = "clarifai"
ANTHROPIC = "anthropic"
ANTHROPIC_TEXT = "anthropic_text"
REPLICATE = "replicate"
HUGGINGFACE = "huggingface"
TOGETHER_AI = "together_ai"
OPENROUTER = "openrouter"
VERTEX_AI = "vertex_ai"
VERTEX_AI_BETA = "vertex_ai_beta"
GEMINI = "gemini"
AI21 = "ai21"
BASETEN = "baseten"
AZURE = "azure"
AZURE_TEXT = "azure_text"
AZURE_AI = "azure_ai"
SAGEMAKER = "sagemaker"
SAGEMAKER_CHAT = "sagemaker_chat"
BEDROCK = "bedrock"
VLLM = "vllm"
NLP_CLOUD = "nlp_cloud"
PETALS = "petals"
OOBABOOGA = "oobabooga"
OLLAMA = "ollama"
OLLAMA_CHAT = "ollama_chat"
DEEPINFRA = "deepinfra"
PERPLEXITY = "perplexity"
MISTRAL = "mistral"
GROQ = "groq"
NVIDIA_NIM = "nvidia_nim"
CEREBRAS = "cerebras"
AI21_CHAT = "ai21_chat"
VOLCENGINE = "volcengine"
CODESTRAL = "codestral"
TEXT_COMPLETION_CODESTRAL = "text-completion-codestral"
DEEPSEEK = "deepseek"
SAMBANOVA = "sambanova"
MARITALK = "maritalk"
VOYAGE = "voyage"
CLOUDFLARE = "cloudflare"
XINFERENCE = "xinference"
FIREWORKS_AI = "fireworks_ai"
FRIENDLIAI = "friendliai"
WATSONX = "watsonx"
WATSONX_TEXT = "watsonx_text"
TRITON = "triton"
PREDIBASE = "predibase"
DATABRICKS = "databricks"
EMPOWER = "empower"
GITHUB = "github"
CUSTOM = "custom"
LITELLM_PROXY = "litellm_proxy"
HOSTED_VLLM = "hosted_vllm"
LM_STUDIO = "lm_studio"
GALADRIEL = "galadriel"
class LiteLLMLoggingBaseClass:
"""
Base class for logging pre and post call
Meant to simplify type checking for logging obj.
"""
def pre_call(self, input, api_key, model=None, additional_args={}):
pass
def post_call(
self, original_response, input=None, api_key=None, additional_args={}
):
pass
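With LlmProviders and LiteLLMLoggingBaseClass now living alongside the other shared types, provider handlers can type their logging argument without importing the full Logging implementation. A sketch of how a handler might consume the base class, assuming it is exported from litellm.types.utils as this hunk suggests; the handler name and signature are illustrative:

from litellm.types.utils import LiteLLMLoggingBaseClass


def process_response(raw_response: dict, logging_obj: LiteLLMLoggingBaseClass) -> dict:
    # Only the pre_call/post_call interface defined above is needed here, so the
    # concrete Logging class (and its import chain) never has to be loaded.
    logging_obj.post_call(original_response=raw_response)
    return raw_response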
View file
@ -126,6 +126,7 @@ from litellm.types.utils import (
EmbeddingResponse, EmbeddingResponse,
Function, Function,
ImageResponse, ImageResponse,
LlmProviders,
Message, Message,
ModelInfo, ModelInfo,
ModelResponse, ModelResponse,
@ -147,6 +148,7 @@ claude_json_str = json.dumps(json_data)
import importlib.metadata import importlib.metadata
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Callable, Callable,
Dict, Dict,
@ -163,6 +165,8 @@ from typing import (
from openai import OpenAIError as OriginalError from openai import OpenAIError as OriginalError
from litellm.llms.base_llm.transformation import BaseConfig
from ._logging import verbose_logger from ._logging import verbose_logger
from .caching.caching import ( from .caching.caching import (
Cache, Cache,
@ -235,7 +239,6 @@ last_fetched_at = None
last_fetched_at_keys = None last_fetched_at_keys = None
######## Model Response ######################### ######## Model Response #########################
# All liteLLM Model responses will be in this format, Follows the OpenAI Format # All liteLLM Model responses will be in this format, Follows the OpenAI Format
# https://docs.litellm.ai/docs/completion/output # https://docs.litellm.ai/docs/completion/output
# { # {
@ -6205,13 +6208,10 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
return messages return messages
from litellm.llms.base_llm.transformation import BaseConfig
class ProviderConfigManager: class ProviderConfigManager:
@staticmethod @staticmethod
def get_provider_chat_config( # noqa: PLR0915 def get_provider_chat_config( # noqa: PLR0915
model: str, provider: litellm.LlmProviders model: str, provider: LlmProviders
) -> BaseConfig: ) -> BaseConfig:
""" """
Returns the provider config for a given provider. Returns the provider config for a given provider.
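Since utils.py now takes LlmProviders from litellm.types.utils and hoists the BaseConfig import to the top of the module, callers can resolve a provider config without reaching for package-level attributes. A hedged usage sketch, assuming litellm is importable in the current environment; the model string is only an example:

from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

# Resolve the chat transformation config for a provider; the return value is a
# BaseConfig subclass, so callers can type against BaseConfig directly.
config = ProviderConfigManager.get_provider_chat_config(
    model="claude-3-5-sonnet-20240620",
    provider=LlmProviders.ANTHROPIC,
)
print(type(config).__name__)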
View file
@ -173,8 +173,8 @@ def main():
"list_organization", "list_organization",
"user_update", "user_update",
] ]
directory = "../../litellm/proxy/management_endpoints" # LOCAL # directory = "../../litellm/proxy/management_endpoints" # LOCAL
# directory = "./litellm/proxy/management_endpoints" directory = "./litellm/proxy/management_endpoints"
# Convert function names to set for faster lookup # Convert function names to set for faster lookup
target_functions = set(function_names) target_functions = set(function_names)
View file
@ -0,0 +1,162 @@
import os
import ast
import sys
from typing import List, Tuple, Optional
def find_litellm_type_hints(directory: str) -> List[Tuple[str, int, str]]:
"""
Recursively search for Python files in the given directory
and find type hints containing 'litellm.'.
Args:
directory (str): The root directory to search for Python files
Returns:
List of tuples containing (file_path, line_number, type_hint)
"""
litellm_type_hints = []
def is_litellm_type_hint(node):
"""
Recursively check if a type annotation contains 'litellm.'
Handles more complex type hints like:
- Optional[litellm.Type]
- Union[litellm.Type1, litellm.Type2]
- Nested type hints
"""
try:
# Convert node to string representation
type_str = ast.unparse(node)
# Direct check for litellm in type string
if "litellm." in type_str:
return True
# Handle more complex type hints
if isinstance(node, ast.Subscript):
# Check Union or Optional types
if isinstance(node.value, ast.Name) and node.value.id in [
"Union",
"Optional",
]:
# Check each element in the Union/Optional type
if isinstance(node.slice, ast.Tuple):
return any(is_litellm_type_hint(elt) for elt in node.slice.elts)
else:
return is_litellm_type_hint(node.slice)
# Recursive check for subscripted types
return is_litellm_type_hint(node.value) or is_litellm_type_hint(
node.slice
)
# Recursive check for attribute types
if isinstance(node, ast.Attribute):
return "litellm." in ast.unparse(node)
# Recursive check for name types
if isinstance(node, ast.Name):
return "litellm" in node.id
return False
except Exception:
# Fallback to string checking if parsing fails
try:
return "litellm." in ast.unparse(node)
except:
return False
def scan_file(file_path: str):
"""
Scan a single Python file for LiteLLM type hints
"""
try:
# Use utf-8-sig to handle files with BOM, ignore errors
with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file:
tree = ast.parse(file.read())
for node in ast.walk(tree):
# Check type annotations in variable annotations
if isinstance(node, ast.AnnAssign) and node.annotation:
if is_litellm_type_hint(node.annotation):
litellm_type_hints.append(
(file_path, node.lineno, ast.unparse(node.annotation))
)
# Check type hints in function arguments
elif isinstance(node, ast.FunctionDef):
for arg in node.args.args:
if arg.annotation and is_litellm_type_hint(arg.annotation):
litellm_type_hints.append(
(file_path, arg.lineno, ast.unparse(arg.annotation))
)
# Check return type annotation
if node.returns and is_litellm_type_hint(node.returns):
litellm_type_hints.append(
(file_path, node.lineno, ast.unparse(node.returns))
)
except SyntaxError as e:
print(f"Syntax error in {file_path}: {e}", file=sys.stderr)
except Exception as e:
print(f"Error processing {file_path}: {e}", file=sys.stderr)
# Recursively walk through directory
for root, dirs, files in os.walk(directory):
# Remove virtual environment and cache directories from search
dirs[:] = [
d
for d in dirs
if not any(
venv in d
for venv in [
"venv",
"env",
"myenv",
".venv",
"__pycache__",
".pytest_cache",
]
)
]
for file in files:
if file.endswith(".py"):
full_path = os.path.join(root, file)
# Skip files in virtual environment or cache directories
if not any(
venv in full_path
for venv in [
"venv",
"env",
"myenv",
".venv",
"__pycache__",
".pytest_cache",
]
):
scan_file(full_path)
return litellm_type_hints
def main():
# Get directory from command line argument or use current directory
directory = "./litellm/"
# Find LiteLLM type hints
results = find_litellm_type_hints(directory)
# Print results
if results:
print("LiteLLM Type Hints Found:")
for file_path, line_num, type_hint in results:
print(f"{file_path}:{line_num} - {type_hint}")
else:
print("No LiteLLM type hints found.")
if __name__ == "__main__":
main()
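The new check simply walks ./litellm/ and prints every annotation whose unparsed form contains "litellm.". A hedged before/after illustration of what it reports versus what this commit migrates to; the function names are hypothetical:

import litellm
from litellm.types.utils import ModelResponse


def flagged(model_response: litellm.ModelResponse) -> None:
    # Reported by the checker: the annotation string contains "litellm.".
    ...


def clean(model_response: ModelResponse) -> None:
    # Passes: the type is imported directly from litellm.types.utils.
    ...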