""" OpenAI-like chat completion handler For handling OpenAI-like chat completions, like IBM WatsonX, etc. """ import copy import json import os import time import types from enum import Enum from functools import partial from typing import Any, Callable, List, Literal, Optional, Tuple, Union import httpx import litellm from litellm import LlmProviders from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, get_async_httpx_client, ) from litellm.llms.databricks.streaming_utils import ModelResponseIterator from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig from litellm.llms.openai.openai import OpenAIConfig from litellm.types.utils import CustomStreamingDecoder, ModelResponse from litellm.utils import ( Choices, CustomStreamWrapper, EmbeddingResponse, Message, ProviderConfigManager, TextCompletionResponse, Usage, convert_to_model_response_object, ) from ..common_utils import OpenAILikeBase, OpenAILikeError from .transformation import OpenAILikeChatConfig async def make_call( client: Optional[AsyncHTTPHandler], api_base: str, headers: dict, data: str, model: str, messages: list, logging_obj, streaming_decoder: Optional[CustomStreamingDecoder] = None, fake_stream: bool = False, ): if client is None: client = litellm.module_level_aclient response = await client.post( api_base, headers=headers, data=data, stream=not fake_stream ) if streaming_decoder is not None: completion_stream: Any = streaming_decoder.aiter_bytes( response.aiter_bytes(chunk_size=1024) ) elif fake_stream: model_response = ModelResponse(**response.json()) completion_stream = MockResponseIterator(model_response=model_response) else: completion_stream = ModelResponseIterator( streaming_response=response.aiter_lines(), sync_stream=False ) # LOGGING logging_obj.post_call( input=messages, api_key="", original_response=completion_stream, # Pass the completion stream for logging additional_args={"complete_input_dict": data}, ) return completion_stream def make_sync_call( client: Optional[HTTPHandler], api_base: str, headers: dict, data: str, model: str, messages: list, logging_obj, streaming_decoder: Optional[CustomStreamingDecoder] = None, fake_stream: bool = False, ): if client is None: client = litellm.module_level_client # Create a new client if none provided response = client.post(api_base, headers=headers, data=data, stream=not fake_stream) if response.status_code != 200: raise OpenAILikeError(status_code=response.status_code, message=response.read()) if streaming_decoder is not None: completion_stream = streaming_decoder.iter_bytes( response.iter_bytes(chunk_size=1024) ) elif fake_stream: model_response = ModelResponse(**response.json()) completion_stream = MockResponseIterator(model_response=model_response) else: completion_stream = ModelResponseIterator( streaming_response=response.iter_lines(), sync_stream=True ) # LOGGING logging_obj.post_call( input=messages, api_key="", original_response="first stream response received", additional_args={"complete_input_dict": data}, ) return completion_stream class OpenAILikeChatHandler(OpenAILikeBase): def __init__(self, **kwargs): super().__init__(**kwargs) async def acompletion_stream_function( self, model: str, messages: list, custom_llm_provider: str, api_base: str, custom_prompt_dict: dict, model_response: ModelResponse, print_verbose: Callable, encoding, api_key, logging_obj, stream, data: 

class OpenAILikeChatHandler(OpenAILikeBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def acompletion_stream_function(
        self,
        model: str,
        messages: list,
        custom_llm_provider: str,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
        streaming_decoder: Optional[CustomStreamingDecoder] = None,
        fake_stream: bool = False,
    ) -> CustomStreamWrapper:
        if fake_stream is not True:
            # fake_stream expects a complete (non-streamed) upstream response
            data["stream"] = True
        completion_stream = await make_call(
            client=client,
            api_base=api_base,
            headers=headers,
            data=json.dumps(data),
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            streaming_decoder=streaming_decoder,
            fake_stream=fake_stream,
        )
        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider=custom_llm_provider,
            logging_obj=logging_obj,
        )

        return streamwrapper

    async def acompletion_function(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        custom_llm_provider: str,
        print_verbose: Callable,
        client: Optional[AsyncHTTPHandler],
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        base_model: Optional[str],
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        json_mode: bool = False,
    ) -> ModelResponse:
        if timeout is None:
            timeout = httpx.Timeout(timeout=600.0, connect=5.0)

        if client is None:
            client = litellm.module_level_aclient

        try:
            response = await client.post(
                api_base, headers=headers, data=json.dumps(data), timeout=timeout
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise OpenAILikeError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        except httpx.TimeoutException:
            raise OpenAILikeError(status_code=408, message="Timeout error occurred.")
        except Exception as e:
            raise OpenAILikeError(status_code=500, message=str(e))

        return OpenAILikeChatConfig._transform_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            optional_params=optional_params,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
            json_mode=json_mode,
            custom_llm_provider=custom_llm_provider,
            base_model=base_model,
        )
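
    # -----------------------------------------------------------------------
    # For reference, `completion` below assembles an OpenAI-style JSON payload
    # of roughly this shape (values here are illustrative placeholders only):
    #
    #   {
    #       "model": "my-model",
    #       "messages": [{"role": "user", "content": "Hello!"}],
    #       "stream": false,
    #       ...any remaining optional_params / extra_body keys...
    #   }
    #
    # and POSTs it to the chat-completions endpoint resolved by
    # `self._validate_environment`.
    # -----------------------------------------------------------------------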
    def completion(
        self,
        *,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: Optional[str],
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        custom_endpoint: Optional[bool] = None,
        streaming_decoder: Optional[
            CustomStreamingDecoder
        ] = None,  # if an openai-compatible api needs a custom stream decoder - e.g. sagemaker
        fake_stream: bool = False,
    ):
        custom_endpoint = custom_endpoint or optional_params.pop(
            "custom_endpoint", None
        )
        base_model: Optional[str] = optional_params.pop("base_model", None)
        api_base, headers = self._validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="chat_completions",
            custom_endpoint=custom_endpoint,
            headers=headers,
        )

        stream: bool = optional_params.pop("stream", None) or False
        extra_body = optional_params.pop("extra_body", {})
        json_mode = optional_params.pop("json_mode", None)
        optional_params.pop("max_retries", None)
        if not fake_stream:
            optional_params["stream"] = stream

        if messages is not None and custom_llm_provider is not None:
            provider_config = ProviderConfigManager.get_provider_chat_config(
                model=model, provider=LlmProviders(custom_llm_provider)
            )
            if isinstance(provider_config, (OpenAIGPTConfig, OpenAIConfig)):
                messages = provider_config._transform_messages(messages)

        data = {
            "model": model,
            "messages": messages,
            **optional_params,
            **extra_body,
        }

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": headers,
            },
        )

        if acompletion is True:
            if client is None or not isinstance(client, AsyncHTTPHandler):
                client = None
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                data["stream"] = stream
                return self.acompletion_stream_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    client=client,
                    custom_llm_provider=custom_llm_provider,
                    streaming_decoder=streaming_decoder,
                    fake_stream=fake_stream,
                )
            else:
                return self.acompletion_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    custom_llm_provider=custom_llm_provider,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                    base_model=base_model,
                    client=client,
                )
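        # The two branches above return coroutines for the caller to await;
        # the `else` branch below performs the HTTP call synchronously, either
        # wrapping a live byte stream or issuing a single blocking POST.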
        else:
            ## COMPLETION CALL
            if stream is True:
                completion_stream = make_sync_call(
                    client=(
                        client
                        if client is not None and isinstance(client, HTTPHandler)
                        else None
                    ),
                    api_base=api_base,
                    headers=headers,
                    data=json.dumps(data),
                    model=model,
                    messages=messages,
                    logging_obj=logging_obj,
                    streaming_decoder=streaming_decoder,
                    fake_stream=fake_stream,
                )
                # completion_stream.__iter__()
                return CustomStreamWrapper(
                    completion_stream=completion_stream,
                    model=model,
                    custom_llm_provider=custom_llm_provider,
                    logging_obj=logging_obj,
                )
            else:
                if client is None or not isinstance(client, HTTPHandler):
                    client = HTTPHandler(timeout=timeout)  # type: ignore
                try:
                    response = client.post(
                        api_base, headers=headers, data=json.dumps(data)
                    )
                    response.raise_for_status()
                except httpx.HTTPStatusError as e:
                    raise OpenAILikeError(
                        status_code=e.response.status_code,
                        message=e.response.text,
                    )
                except httpx.TimeoutException:
                    raise OpenAILikeError(
                        status_code=408, message="Timeout error occurred."
                    )
                except Exception as e:
                    raise OpenAILikeError(status_code=500, message=str(e))

                return OpenAILikeChatConfig._transform_response(
                    model=model,
                    response=response,
                    model_response=model_response,
                    stream=stream,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    api_key=api_key,
                    data=data,
                    messages=messages,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    json_mode=json_mode,
                    custom_llm_provider=custom_llm_provider,
                    base_model=base_model,
                )
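
# ---------------------------------------------------------------------------
# Illustrative usage (a sketch, not executed here): this handler is normally
# reached indirectly through `litellm.completion`, by pointing a provider that
# speaks the OpenAI chat-completions wire format at a custom `api_base`. The
# provider/model name, endpoint URL, and API key below are placeholder
# assumptions for demonstration only:
#
#     import litellm
#
#     response = litellm.completion(
#         model="openai/my-model",              # placeholder provider/model
#         api_base="http://localhost:8000/v1",  # placeholder endpoint
#         api_key="sk-placeholder",             # placeholder key
#         messages=[{"role": "user", "content": "Hello!"}],
#     )
#     print(response)
# ---------------------------------------------------------------------------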