diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index 4f5d4f263..7085e58b3 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -307,7 +307,7 @@ class BedrockLLM(BaseLLM): try: if provider == "cohere": - model_response.choices[0].message.content = completion_response["text"] # type: ignore + outputText = completion_response["text"] # type: ignore elif provider == "anthropic": if model.startswith("anthropic.claude-3"): json_schemas: dict = {} @@ -427,6 +427,15 @@ class BedrockLLM(BaseLLM): outputText = ( completion_response.get("completions")[0].get("data").get("text") ) + elif provider == "meta": + outputText = completion_response["generation"] + elif provider == "mistral": + outputText = completion_response["outputs"][0]["text"] + model_response["finish_reason"] = completion_response["outputs"][0][ + "stop_reason" + ] + else: # amazon titan + outputText = completion_response.get("results")[0].get("outputText") except Exception as e: raise BedrockError( message="Error processing={}, Received error={}".format( @@ -691,6 +700,40 @@ class BedrockLLM(BaseLLM): inference_params[k] = v data = json.dumps({"prompt": prompt, **inference_params}) + elif provider == "mistral": + ## LOAD CONFIG + config = litellm.AmazonMistralConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + + data = json.dumps({"prompt": prompt, **inference_params}) + elif provider == "amazon": # amazon titan + ## LOAD CONFIG + config = litellm.AmazonTitanConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + + data = json.dumps( + { + "inputText": prompt, + "textGenerationConfig": inference_params, + } + ) + elif provider == "meta": + ## LOAD CONFIG + config = litellm.AmazonLlamaConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + data = json.dumps({"prompt": prompt, **inference_params}) else: raise Exception("UNSUPPORTED PROVIDER") diff --git a/litellm/main.py b/litellm/main.py index 769b5964a..0fad87d6d 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -326,10 +326,7 @@ async def acompletion( or custom_llm_provider == "sagemaker" or custom_llm_provider == "anthropic" or custom_llm_provider == "predibase" - or ( - custom_llm_provider == "bedrock" - and ("cohere" in model or "anthropic" in model or "ai21" in model) - ) + or custom_llm_provider == "bedrock" or custom_llm_provider in litellm.openai_compatible_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. init_response = await loop.run_in_executor(None, func_with_context) @@ -1982,59 +1979,21 @@ def completion( # boto3 reads keys from .env custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - if "cohere" in model or "anthropic" in model or "ai21" in model: - response = bedrock_chat_completion.completion( - model=model, - messages=messages, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - extra_headers=extra_headers, - timeout=timeout, - acompletion=acompletion, - ) - else: - response = bedrock.completion( - model=model, - messages=messages, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - extra_headers=extra_headers, - timeout=timeout, - ) - - if ( - "stream" in optional_params - and optional_params["stream"] == True - and not isinstance(response, CustomStreamWrapper) - ): - # don't try to access stream object, - if "ai21" in model: - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) - else: - response = CustomStreamWrapper( - iter(response), - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) - + response = bedrock_chat_completion.completion( + model=model, + messages=messages, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + acompletion=acompletion, + ) if optional_params.get("stream", False): ## LOGGING logging.post_call( diff --git a/litellm/tests/log.txt b/litellm/tests/log.txt index 4d3027355..b3c9d4a09 100644 --- a/litellm/tests/log.txt +++ b/litellm/tests/log.txt @@ -1,4045 +1,1061 @@ ============================= test session starts ============================== -platform darwin -- Python 3.11.9, pytest-7.3.1, pluggy-1.3.0 +platform darwin -- Python 3.11.9, pytest-7.3.1, pluggy-1.3.0 -- /opt/homebrew/opt/python@3.11/bin/python3.11 +cachedir: .pytest_cache rootdir: /Users/krrishdholakia/Documents/litellm/litellm/tests plugins: timeout-2.2.0, asyncio-0.23.2, anyio-3.7.1, xdist-3.3.1 asyncio: mode=Mode.STRICT -collected 2 items +collecting ... collected 2 items -test_streaming.py .Token Counter - using hugging face token counter, for model=llama-3-8b-instruct -Looking up model=llama-3-8b-instruct in model_cost_map -F [100%] +test_streaming.py::test_bedrock_httpx_streaming[bedrock/amazon.titan-tg1-large-False] FAILED [ 50%] =================================== FAILURES =================================== -__________________ test_completion_predibase_streaming[True] ___________________ +______ test_bedrock_httpx_streaming[bedrock/amazon.titan-tg1-large-False] ______ -model = 'llama-3-8b-instruct' -messages = [{'content': 'What is the meaning of life?', 'role': 'user'}] -timeout = 600.0, temperature = None, top_p = None, n = None, stream = True -stream_options = None, stop = None, max_tokens = None, presence_penalty = None -frequency_penalty = None, logit_bias = None, user = None, response_format = None -seed = None, tools = None, tool_choice = None, logprobs = None -top_logprobs = None, deployment_id = None, extra_headers = None -functions = None, function_call = None, base_url = None, api_version = None -api_key = 'pb_Qg9YbQo7UqqHdu0ozxN_aw', model_list = None -kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , 'tenant_id': 'c4768f95'} -args = {'acompletion': False, 'api_base': None, 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'api_version': None, ...} -api_base = None, mock_response = None, force_timeout = 600, logger_fn = None -verbose = False, custom_llm_provider = 'predibase' +self = +chunk = {'finish_reason': '', 'is_finished': False, 'text': '\nHello, I am an AI model developed by Amazon Titan Foundation Mo...able of understanding and generating human-like text. My development has been focused on continuously improving my pe'} - @client - def completion( - model: str, - # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create - messages: List = [], - timeout: Optional[Union[float, str, httpx.Timeout]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - n: Optional[int] = None, - stream: Optional[bool] = None, - stream_options: Optional[dict] = None, - stop=None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[dict] = None, - user: Optional[str] = None, - # openai v1.0+ new params - response_format: Optional[dict] = None, - seed: Optional[int] = None, - tools: Optional[List] = None, - tool_choice: Optional[str] = None, - logprobs: Optional[bool] = None, - top_logprobs: Optional[int] = None, - deployment_id=None, - extra_headers: Optional[dict] = None, - # soon to be deprecated params by OpenAI - functions: Optional[List] = None, - function_call: Optional[str] = None, - # set api_base, api_version, api_key - base_url: Optional[str] = None, - api_version: Optional[str] = None, - api_key: Optional[str] = None, - model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. - # Optional liteLLM function params - **kwargs, - ) -> Union[ModelResponse, CustomStreamWrapper]: - """ - Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) - Parameters: - model (str): The name of the language model to use for text completion. see all supported LLMs: https://docs.litellm.ai/docs/providers/ - messages (List): A list of message objects representing the conversation context (default is an empty list). - - OPTIONAL PARAMS - functions (List, optional): A list of functions to apply to the conversation messages (default is an empty list). - function_call (str, optional): The name of the function to call within the conversation (default is an empty string). - temperature (float, optional): The temperature parameter for controlling the randomness of the output (default is 1.0). - top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). - n (int, optional): The number of completions to generate (default is 1). - stream (bool, optional): If True, return a streaming response (default is False). - stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true. - stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. - max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). - presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. - frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. - logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. - user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. - logprobs (bool, optional): Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message - top_logprobs (int, optional): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. - metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. - api_base (str, optional): Base URL for the API (default is None). - api_version (str, optional): API version (default is None). - api_key (str, optional): API key (default is None). - model_list (list, optional): List of api base, version, keys - extra_headers (dict, optional): Additional headers to include in the request. - - LITELLM Specific Params - mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None). - custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock" - max_retries (int, optional): The number of retries to attempt (default is 0). - Returns: - ModelResponse: A response object containing the generated completion and associated metadata. - - Note: - - This function is used to perform completions() using the specified language model. - - It supports various optional parameters for customizing the completion behavior. - - If 'mock_response' is provided, a mock completion response is returned for testing or debugging. - """ - ######### unpacking kwargs ##################### - args = locals() - api_base = kwargs.get("api_base", None) - mock_response = kwargs.get("mock_response", None) - force_timeout = kwargs.get("force_timeout", 600) ## deprecated - logger_fn = kwargs.get("logger_fn", None) - verbose = kwargs.get("verbose", False) - custom_llm_provider = kwargs.get("custom_llm_provider", None) - litellm_logging_obj = kwargs.get("litellm_logging_obj", None) - id = kwargs.get("id", None) - metadata = kwargs.get("metadata", None) - model_info = kwargs.get("model_info", None) - proxy_server_request = kwargs.get("proxy_server_request", None) - fallbacks = kwargs.get("fallbacks", None) - headers = kwargs.get("headers", None) - num_retries = kwargs.get("num_retries", None) ## deprecated - max_retries = kwargs.get("max_retries", None) - context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None) - organization = kwargs.get("organization", None) - ### CUSTOM MODEL COST ### - input_cost_per_token = kwargs.get("input_cost_per_token", None) - output_cost_per_token = kwargs.get("output_cost_per_token", None) - input_cost_per_second = kwargs.get("input_cost_per_second", None) - output_cost_per_second = kwargs.get("output_cost_per_second", None) - ### CUSTOM PROMPT TEMPLATE ### - initial_prompt_value = kwargs.get("initial_prompt_value", None) - roles = kwargs.get("roles", None) - final_prompt_value = kwargs.get("final_prompt_value", None) - bos_token = kwargs.get("bos_token", None) - eos_token = kwargs.get("eos_token", None) - preset_cache_key = kwargs.get("preset_cache_key", None) - hf_model_name = kwargs.get("hf_model_name", None) - supports_system_message = kwargs.get("supports_system_message", None) - ### TEXT COMPLETION CALLS ### - text_completion = kwargs.get("text_completion", False) - atext_completion = kwargs.get("atext_completion", False) - ### ASYNC CALLS ### - acompletion = kwargs.get("acompletion", False) - client = kwargs.get("client", None) - ### Admin Controls ### - no_log = kwargs.get("no-log", False) - ######## end of unpacking kwargs ########### - openai_params = [ - "functions", - "function_call", - "temperature", - "temperature", - "top_p", - "n", - "stream", - "stream_options", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "request_timeout", - "api_base", - "api_version", - "api_key", - "deployment_id", - "organization", - "base_url", - "default_headers", - "timeout", - "response_format", - "seed", - "tools", - "tool_choice", - "max_retries", - "logprobs", - "top_logprobs", - "extra_headers", - ] - litellm_params = [ - "metadata", - "acompletion", - "atext_completion", - "text_completion", - "caching", - "mock_response", - "api_key", - "api_version", - "api_base", - "force_timeout", - "logger_fn", - "verbose", - "custom_llm_provider", - "litellm_logging_obj", - "litellm_call_id", - "use_client", - "id", - "fallbacks", - "azure", - "headers", - "model_list", - "num_retries", - "context_window_fallback_dict", - "retry_policy", - "roles", - "final_prompt_value", - "bos_token", - "eos_token", - "request_timeout", - "complete_response", - "self", - "client", - "rpm", - "tpm", - "max_parallel_requests", - "input_cost_per_token", - "output_cost_per_token", - "input_cost_per_second", - "output_cost_per_second", - "hf_model_name", - "model_info", - "proxy_server_request", - "preset_cache_key", - "caching_groups", - "ttl", - "cache", - "no-log", - "base_model", - "stream_timeout", - "supports_system_message", - "region_name", - "allowed_model_region", - ] - default_params = openai_params + litellm_params - non_default_params = { - k: v for k, v in kwargs.items() if k not in default_params - } # model-specific params - pass them straight to the model/provider - - ### TIMEOUT LOGIC ### - timeout = timeout or kwargs.get("request_timeout", 600) or 600 - # set timeout for 10 minutes by default - - if ( - timeout is not None - and isinstance(timeout, httpx.Timeout) - and supports_httpx_timeout(custom_llm_provider) == False - ): - read_timeout = timeout.read or 600 - timeout = read_timeout # default 10 min timeout - elif timeout is not None and not isinstance(timeout, httpx.Timeout): - timeout = float(timeout) # type: ignore - + def chunk_creator(self, chunk): + model_response = self.model_response_creator() + response_obj = {} try: - if base_url is not None: - api_base = base_url - if max_retries is not None: # openai allows openai.OpenAI(max_retries=3) - num_retries = max_retries - logging = litellm_logging_obj - fallbacks = fallbacks or litellm.model_fallbacks - if fallbacks is not None: - return completion_with_fallbacks(**args) - if model_list is not None: - deployments = [ - m["litellm_params"] for m in model_list if m["model_name"] == model - ] - return batch_completion_models(deployments=deployments, **args) - if litellm.model_alias_map and model in litellm.model_alias_map: - model = litellm.model_alias_map[ - model - ] # update the model to the actual value if an alias has been passed in - model_response = ModelResponse() - setattr(model_response, "usage", litellm.Usage()) - if ( - kwargs.get("azure", False) == True - ): # don't remove flag check, to remain backwards compatible for repos like Codium - custom_llm_provider = "azure" - if deployment_id != None: # azure llms - model = deployment_id - custom_llm_provider = "azure" - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - api_key=api_key, - ) - if model_response is not None and hasattr(model_response, "_hidden_params"): - model_response._hidden_params["custom_llm_provider"] = custom_llm_provider - model_response._hidden_params["region_name"] = kwargs.get( - "aws_region_name", None - ) # support region-based pricing for bedrock - - ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### - if input_cost_per_token is not None and output_cost_per_token is not None: - print_verbose(f"Registering model={model} in model cost map") - litellm.register_model( - { - f"{custom_llm_provider}/{model}": { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider, - }, - model: { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider, - }, - } - ) + # return this for all models + completion_obj = {"content": ""} + if self.custom_llm_provider and self.custom_llm_provider == "anthropic": + response_obj = self.handle_anthropic_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] elif ( - input_cost_per_second is not None - ): # time based pricing just needs cost in place - output_cost_per_second = output_cost_per_second - litellm.register_model( - { - f"{custom_llm_provider}/{model}": { - "input_cost_per_second": input_cost_per_second, - "output_cost_per_second": output_cost_per_second, - "litellm_provider": custom_llm_provider, - }, - model: { - "input_cost_per_second": input_cost_per_second, - "output_cost_per_second": output_cost_per_second, - "litellm_provider": custom_llm_provider, - }, - } - ) - ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### - custom_prompt_dict = {} # type: ignore - if ( - initial_prompt_value - or roles - or final_prompt_value - or bos_token - or eos_token + self.custom_llm_provider + and self.custom_llm_provider == "anthropic_text" ): - custom_prompt_dict = {model: {}} - if initial_prompt_value: - custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value - if roles: - custom_prompt_dict[model]["roles"] = roles - if final_prompt_value: - custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value - if bos_token: - custom_prompt_dict[model]["bos_token"] = bos_token - if eos_token: - custom_prompt_dict[model]["eos_token"] = eos_token - - if ( - supports_system_message is not None - and isinstance(supports_system_message, bool) - and supports_system_message == False - ): - messages = map_system_message_pt(messages=messages) - model_api_key = get_api_key( - llm_provider=custom_llm_provider, dynamic_api_key=api_key - ) # get the api key from the environment if required for the model - - if dynamic_api_key is not None: - api_key = dynamic_api_key - # check if user passed in any of the OpenAI optional params - optional_params = get_optional_params( - functions=functions, - function_call=function_call, - temperature=temperature, - top_p=top_p, - n=n, - stream=stream, - stream_options=stream_options, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - user=user, - # params to identify the model - model=model, - custom_llm_provider=custom_llm_provider, - response_format=response_format, - seed=seed, - tools=tools, - tool_choice=tool_choice, - max_retries=max_retries, - logprobs=logprobs, - top_logprobs=top_logprobs, - extra_headers=extra_headers, - **non_default_params, - ) - - if litellm.add_function_to_prompt and optional_params.get( - "functions_unsupported_model", None - ): # if user opts to add it to prompt, when API doesn't support function calling - functions_unsupported_model = optional_params.pop( - "functions_unsupported_model" - ) - messages = function_call_prompt( - messages=messages, functions=functions_unsupported_model - ) - - # For logging - save the values of the litellm-specific params passed in - litellm_params = get_litellm_params( - acompletion=acompletion, - api_key=api_key, - force_timeout=force_timeout, - logger_fn=logger_fn, - verbose=verbose, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - litellm_call_id=kwargs.get("litellm_call_id", None), - model_alias_map=litellm.model_alias_map, - completion_call_id=id, - metadata=metadata, - model_info=model_info, - proxy_server_request=proxy_server_request, - preset_cache_key=preset_cache_key, - no_log=no_log, - ) - logging.update_environment_variables( - model=model, - user=user, - optional_params=optional_params, - litellm_params=litellm_params, - ) - if mock_response: - return mock_completion( - model, - messages, - stream=stream, - mock_response=mock_response, - logging=logging, - acompletion=acompletion, - ) - if custom_llm_provider == "azure": - # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" - - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") - - api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") - ) - - azure_ad_token = optional_params.get("extra_body", {}).pop( - "azure_ad_token", None - ) or get_secret("AZURE_AD_TOKEN") - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.AzureOpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - response = azure_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - api_version=api_version, - api_type=api_type, - azure_ad_token=azure_ad_token, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, # type: ignore - client=client, # pass AsyncAzureOpenAI, AzureOpenAI client - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={ - "headers": headers, - "api_version": api_version, - "api_base": api_base, - }, - ) - elif custom_llm_provider == "azure_text": - # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" - - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") - - api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") - ) - - azure_ad_token = optional_params.get("extra_body", {}).pop( - "azure_ad_token", None - ) or get_secret("AZURE_AD_TOKEN") - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.AzureOpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - response = azure_text_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - api_version=api_version, - api_type=api_type, - azure_ad_token=azure_ad_token, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, - client=client, # pass AsyncAzureOpenAI, AzureOpenAI client - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={ - "headers": headers, - "api_version": api_version, - "api_base": api_base, - }, - ) + response_obj = self.handle_anthropic_text_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "clarifai": + response_obj = self.handle_clarifai_completion_chunk(chunk) + completion_obj["content"] = response_obj["text"] + elif self.model == "replicate" or self.custom_llm_provider == "replicate": + response_obj = self.handle_replicate_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "together_ai": + response_obj = self.handle_together_ai_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "huggingface": + response_obj = self.handle_huggingface_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "predibase": + response_obj = self.handle_predibase_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] elif ( - model in litellm.open_ai_chat_completion_models - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "deepinfra" - or custom_llm_provider == "perplexity" - or custom_llm_provider == "groq" - or custom_llm_provider == "deepseek" - or custom_llm_provider == "anyscale" - or custom_llm_provider == "mistral" - or custom_llm_provider == "openai" - or custom_llm_provider == "together_ai" - or custom_llm_provider in litellm.openai_compatible_providers - or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo - ): # allow user to make an openai call with a custom base - # note: if a user sets a custom base - we should ensure this works - # allow for the setting of dynamic and stateful api-bases - api_base = ( - api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there - or litellm.api_base - or get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) - openai.organization = ( - organization - or litellm.organization - or get_secret("OPENAI_ORGANIZATION") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 - ) - # set API KEY - api_key = ( - api_key - or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.OpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL + self.custom_llm_provider and self.custom_llm_provider == "baseten" + ): # baseten doesn't provide streaming + completion_obj["content"] = self.handle_baseten_chunk(chunk) + elif ( + self.custom_llm_provider and self.custom_llm_provider == "ai21" + ): # ai21 doesn't provide streaming + response_obj = self.handle_ai21_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "maritalk": + response_obj = self.handle_maritalk_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "vllm": + completion_obj["content"] = chunk[0].outputs[0].text + elif ( + self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha" + ): # aleph alpha doesn't provide streaming + response_obj = self.handle_aleph_alpha_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "nlp_cloud": try: - response = openai_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - model_response=model_response, - print_verbose=print_verbose, - api_key=api_key, - api_base=api_base, - acompletion=acompletion, - logging_obj=logging, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - timeout=timeout, # type: ignore - custom_prompt_dict=custom_prompt_dict, - client=client, # pass AsyncOpenAI, OpenAI client - organization=organization, - custom_llm_provider=custom_llm_provider, - ) + response_obj = self.handle_nlp_cloud_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] except Exception as e: - ## LOGGING - log the original exception returned - logging.post_call( - input=messages, - api_key=api_key, - original_response=str(e), - additional_args={"headers": headers}, - ) - raise e - - if optional_params.get("stream", False): - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={"headers": headers}, - ) - elif ( - custom_llm_provider == "text-completion-openai" - or "ft:babbage-002" in model - or "ft:davinci-002" in model # support for finetuned completion models - ): - openai.api_type = "openai" - - api_base = ( - api_base - or litellm.api_base - or get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) - - openai.api_version = None - # set API KEY - - api_key = ( - api_key - or litellm.api_key - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.OpenAITextCompletionConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - if litellm.organization: - openai.organization = litellm.organization - - if ( - len(messages) > 0 - and "content" in messages[0] - and type(messages[0]["content"]) == list - ): - # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content'] - # https://platform.openai.com/docs/api-reference/completions/create - prompt = messages[0]["content"] + if self.received_finish_reason: + raise e + else: + if self.sent_first_chunk is False: + raise Exception("An unknown error occurred with the stream") + self.received_finish_reason = "stop" + elif self.custom_llm_provider == "gemini": + if hasattr(chunk, "parts") == True: + try: + if len(chunk.parts) > 0: + completion_obj["content"] = chunk.parts[0].text + if len(chunk.parts) > 0 and hasattr( + chunk.parts[0], "finish_reason" + ): + self.received_finish_reason = chunk.parts[ + 0 + ].finish_reason.name + except: + if chunk.parts[0].finish_reason.name == "SAFETY": + raise Exception( + f"The response was blocked by VertexAI. {str(chunk)}" + ) else: - prompt = " ".join([message["content"] for message in messages]) # type: ignore + completion_obj["content"] = str(chunk) + elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"): + import proto # type: ignore - ## COMPLETION CALL - _response = openai_text_completions.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - api_key=api_key, - api_base=api_base, - acompletion=acompletion, - client=client, # pass AsyncOpenAI, OpenAI client - logging_obj=logging, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - timeout=timeout, # type: ignore - ) + if self.model.startswith("claude-3"): + response_obj = self.handle_vertexai_anthropic_chunk(chunk=chunk) + if response_obj is None: + return + completion_obj["content"] = response_obj["text"] + setattr(model_response, "usage", Usage()) + if response_obj.get("prompt_tokens", None) is not None: + model_response.usage.prompt_tokens = response_obj[ + "prompt_tokens" + ] + if response_obj.get("completion_tokens", None) is not None: + model_response.usage.completion_tokens = response_obj[ + "completion_tokens" + ] + if hasattr(model_response.usage, "prompt_tokens"): + model_response.usage.total_tokens = ( + getattr(model_response.usage, "total_tokens", 0) + + model_response.usage.prompt_tokens + ) + if hasattr(model_response.usage, "completion_tokens"): + model_response.usage.total_tokens = ( + getattr(model_response.usage, "total_tokens", 0) + + model_response.usage.completion_tokens + ) - if ( - optional_params.get("stream", False) == False - and acompletion == False - and text_completion == False - ): - # convert to chat completion response - _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object( - response_object=_response, model_response_object=model_response - ) + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif hasattr(chunk, "candidates") == True: + try: + try: + completion_obj["content"] = chunk.text + except Exception as e: + if "Part has no text." in str(e): + ## check for function calling + function_call = ( + chunk.candidates[0].content.parts[0].function_call + ) - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=_response, - additional_args={"headers": headers}, - ) - response = _response - elif ( - "replicate" in model - or custom_llm_provider == "replicate" - or model in litellm.replicate_models - ): - # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN") - replicate_key = None - replicate_key = ( - api_key - or litellm.replicate_key - or litellm.api_key - or get_secret("REPLICATE_API_KEY") - or get_secret("REPLICATE_API_TOKEN") - ) + args_dict = {} - api_base = ( - api_base - or litellm.api_base - or get_secret("REPLICATE_API_BASE") - or "https://api.replicate.com/v1" - ) + # Check if it's a RepeatedComposite instance + for key, val in function_call.args.items(): + if isinstance( + val, + proto.marshal.collections.repeated.RepeatedComposite, + ): + # If so, convert to list + args_dict[key] = [v for v in val] + else: + args_dict[key] = val - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - model_response = replicate.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=replicate_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - ) - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=replicate_key, - original_response=model_response, - ) - - response = model_response - - elif custom_llm_provider == "anthropic": - api_key = ( - api_key - or litellm.anthropic_key - or litellm.api_key - or os.environ.get("ANTHROPIC_API_KEY") - ) - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - if (model == "claude-2") or (model == "claude-instant-1"): - # call anthropic /completion, only use this route for claude-2, claude-instant-1 - api_base = ( - api_base - or litellm.api_base - or get_secret("ANTHROPIC_API_BASE") - or "https://api.anthropic.com/v1/complete" - ) - response = anthropic_text_completions.completion( - model=model, - messages=messages, - api_base=api_base, - acompletion=acompletion, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=api_key, - logging_obj=logging, - headers=headers, - ) + try: + args_str = json.dumps(args_dict) + except Exception as e: + raise e + _delta_obj = litellm.utils.Delta( + content=None, + tool_calls=[ + { + "id": f"call_{str(uuid.uuid4())}", + "function": { + "arguments": args_str, + "name": function_call.name, + }, + "type": "function", + } + ], + ) + _streaming_response = StreamingChoices(delta=_delta_obj) + _model_response = ModelResponse(stream=True) + _model_response.choices = [_streaming_response] + response_obj = {"original_chunk": _model_response} + else: + raise e + if ( + hasattr(chunk.candidates[0], "finish_reason") + and chunk.candidates[0].finish_reason.name + != "FINISH_REASON_UNSPECIFIED" + ): # every non-final chunk in vertex ai has this + self.received_finish_reason = chunk.candidates[ + 0 + ].finish_reason.name + except Exception as e: + if chunk.candidates[0].finish_reason.name == "SAFETY": + raise Exception( + f"The response was blocked by VertexAI. {str(chunk)}" + ) else: - # call /messages - # default route for all anthropic models - api_base = ( - api_base - or litellm.api_base - or get_secret("ANTHROPIC_API_BASE") - or "https://api.anthropic.com/v1/messages" - ) - response = anthropic_chat_completions.completion( - model=model, - messages=messages, - api_base=api_base, - acompletion=acompletion, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=api_key, - logging_obj=logging, - headers=headers, - ) - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - ) - response = response - elif custom_llm_provider == "nlp_cloud": - nlp_cloud_key = ( - api_key - or litellm.nlp_cloud_key - or get_secret("NLP_CLOUD_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("NLP_CLOUD_API_BASE") - or "https://api.nlpcloud.io/v1/gpu/" - ) - - response = nlp_cloud.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=nlp_cloud_key, - logging_obj=logging, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="nlp_cloud", - logging_obj=logging, - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - ) - - response = response - elif custom_llm_provider == "aleph_alpha": - aleph_alpha_key = ( - api_key - or litellm.aleph_alpha_key - or get_secret("ALEPH_ALPHA_API_KEY") - or get_secret("ALEPHALPHA_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("ALEPH_ALPHA_API_BASE") - or "https://api.aleph-alpha.com/complete" - ) - - model_response = aleph_alpha.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - default_max_tokens_to_sample=litellm.max_tokens, - api_key=aleph_alpha_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="aleph_alpha", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "cohere": - cohere_key = ( - api_key - or litellm.cohere_key - or get_secret("COHERE_API_KEY") - or get_secret("CO_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("COHERE_API_BASE") - or "https://api.cohere.ai/v1/generate" - ) - - model_response = cohere.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=cohere_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="cohere", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "cohere_chat": - cohere_key = ( - api_key - or litellm.cohere_key - or get_secret("COHERE_API_KEY") - or get_secret("CO_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("COHERE_API_BASE") - or "https://api.cohere.ai/v1/chat" - ) - - model_response = cohere_chat.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=cohere_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="cohere_chat", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "maritalk": - maritalk_key = ( - api_key - or litellm.maritalk_key - or get_secret("MARITALK_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("MARITALK_API_BASE") - or "https://chat.maritaca.ai/api/chat/inference" - ) - - model_response = maritalk.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=maritalk_key, - logging_obj=logging, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="maritalk", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "huggingface": - custom_llm_provider = "huggingface" - huggingface_key = ( - api_key - or litellm.huggingface_key - or os.environ.get("HF_TOKEN") - or os.environ.get("HUGGINGFACE_API_KEY") - or litellm.api_key - ) - hf_headers = headers or litellm.headers - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - model_response = huggingface.completion( - model=model, - messages=messages, - api_base=api_base, # type: ignore - headers=hf_headers, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=huggingface_key, - acompletion=acompletion, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - timeout=timeout, # type: ignore - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion is False - ): - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="huggingface", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "oobabooga": - custom_llm_provider = "oobabooga" - model_response = oobabooga.completion( - model=model, - messages=messages, - model_response=model_response, - api_base=api_base, # type: ignore - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - api_key=None, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - ) - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="oobabooga", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "openrouter": - api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1" - - api_key = ( - api_key - or litellm.api_key - or litellm.openrouter_key - or get_secret("OPENROUTER_API_KEY") - or get_secret("OR_API_KEY") - ) - - openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai" - - openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" - - headers = ( - headers - or litellm.headers - or { - "HTTP-Referer": openrouter_site_url, - "X-Title": openrouter_app_name, - } - ) - - ## Load Config - config = openrouter.OpenrouterConfig.get_config() - for k, v in config.items(): - if k == "extra_body": - # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models - if "extra_body" in optional_params: - optional_params[k].update(v) - else: - optional_params[k] = v - elif k not in optional_params: - optional_params[k] = v - - data = {"model": model, "messages": messages, **optional_params} - - ## COMPLETION CALL - response = openai_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, # type: ignore - ) - ## LOGGING - logging.post_call( - input=messages, api_key=openai.api_key, original_response=response - ) - elif ( - custom_llm_provider == "together_ai" - or ("togethercomputer" in model) - or (model in litellm.together_ai_models) - ): - """ - Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility - """ - custom_llm_provider = "together_ai" - together_ai_key = ( - api_key - or litellm.togetherai_api_key - or get_secret("TOGETHER_AI_TOKEN") - or get_secret("TOGETHERAI_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("TOGETHERAI_API_BASE") - or "https://api.together.xyz/inference" - ) - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - model_response = together_ai.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=together_ai_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - ) - if ( - "stream_tokens" in optional_params - and optional_params["stream_tokens"] == True - ): - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="together_ai", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "palm": - palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key - - # palm does not support streaming as yet :( - model_response = palm.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=palm_api_key, - logging_obj=logging, - ) - # fake palm streaming - if "stream" in optional_params and optional_params["stream"] == True: - # fake streaming for palm - resp_string = model_response["choices"][0]["message"]["content"] - response = CustomStreamWrapper( - resp_string, model, custom_llm_provider="palm", logging_obj=logging - ) - return response - response = model_response - elif custom_llm_provider == "gemini": - gemini_api_key = ( - api_key - or get_secret("GEMINI_API_KEY") - or get_secret("PALM_API_KEY") # older palm api key should also work - or litellm.api_key - ) - - # palm does not support streaming as yet :( - model_response = gemini.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=gemini_api_key, - logging_obj=logging, - acompletion=acompletion, - custom_prompt_dict=custom_prompt_dict, - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion == False - ): - response = CustomStreamWrapper( - iter(model_response), - model, - custom_llm_provider="gemini", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "vertex_ai": - vertex_ai_project = ( - optional_params.pop("vertex_project", None) - or optional_params.pop("vertex_ai_project", None) - or litellm.vertex_project - or get_secret("VERTEXAI_PROJECT") - ) - vertex_ai_location = ( - optional_params.pop("vertex_location", None) - or optional_params.pop("vertex_ai_location", None) - or litellm.vertex_location - or get_secret("VERTEXAI_LOCATION") - ) - vertex_credentials = ( - optional_params.pop("vertex_credentials", None) - or optional_params.pop("vertex_ai_credentials", None) - or get_secret("VERTEXAI_CREDENTIALS") - ) - new_params = deepcopy(optional_params) - if "claude-3" in model: - model_response = vertex_ai_anthropic.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - ) - else: - model_response = vertex_ai.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - ) - - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion == False - ): - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="vertex_ai", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "predibase": - tenant_id = ( - optional_params.pop("tenant_id", None) - or optional_params.pop("predibase_tenant_id", None) - or litellm.predibase_tenant_id - or get_secret("PREDIBASE_TENANT_ID") - ) - - api_base = ( - optional_params.pop("api_base", None) - or optional_params.pop("base_url", None) - or litellm.api_base - or get_secret("PREDIBASE_API_BASE") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.predibase_key - or get_secret("PREDIBASE_API_KEY") - ) - -> model_response = predibase_chat_completions.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - acompletion=acompletion, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - api_key=api_key, - tenant_id=tenant_id, - ) + completion_obj["content"] = str(chunk) + elif self.custom_llm_provider == "cohere": + response_obj = self.handle_cohere_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "cohere_chat": + response_obj = self.handle_cohere_chat_chunk(chunk) + if response_obj is None: + return + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "bedrock": + if self.received_finish_reason is not None: + raise StopIteration +> response_obj = self.handle_bedrock_stream(chunk) -../main.py:1813: +../utils.py:11034: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -self = -model = 'llama-3-8b-instruct' -messages = [{'content': 'What is the meaning of life?', 'role': 'user'}] -api_base = None, custom_prompt_dict = {} -model_response = ModelResponse(id='chatcmpl-755fcb98-22ba-46a2-9d6d-1a85b4363e98', choices=[Choices(finish_reason='stop', index=0, mess... role='assistant'))], created=1715301477, model=None, object='chat.completion', system_fingerprint=None, usage=Usage()) -print_verbose = -encoding = , api_key = 'pb_Qg9YbQo7UqqHdu0ozxN_aw' -logging_obj = -optional_params = {'details': True, 'max_new_tokens': 256, 'return_full_text': False} -tenant_id = 'c4768f95', acompletion = False -litellm_params = {'acompletion': False, 'api_base': 'https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate_stream', 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'completion_call_id': None, ...} -logger_fn = None -headers = {'Authorization': 'Bearer pb_Qg9YbQo7UqqHdu0ozxN_aw', 'content-type': 'application/json'} +self = , chunk = None - def completion( - self, - model: str, - messages: list, - api_base: str, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - api_key: str, - logging_obj, - optional_params: dict, - tenant_id: str, - acompletion=None, - litellm_params=None, - logger_fn=None, - headers: dict = {}, - ) -> Union[ModelResponse, CustomStreamWrapper]: - headers = self.validate_environment(api_key, headers) - completion_url = "" - input_text = "" - base_url = "https://serving.app.predibase.com" - if "https" in model: - completion_url = model - elif api_base: - base_url = api_base - elif "PREDIBASE_API_BASE" in os.environ: - base_url = os.getenv("PREDIBASE_API_BASE", "") - - completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}" - - if optional_params.get("stream", False) == True: - completion_url += "/generate_stream" - else: - completion_url += "/generate" - - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages, - ) - else: - prompt = prompt_factory(model=model, messages=messages) - - ## Load Config - config = litellm.PredibaseConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - stream = optional_params.pop("stream", False) - - data = { - "inputs": prompt, - "parameters": optional_params, - } - input_text = prompt - ## LOGGING - logging_obj.pre_call( - input=input_text, - api_key=api_key, - additional_args={ - "complete_input_dict": data, - "headers": headers, - "api_base": completion_url, - "acompletion": acompletion, - }, - ) - ## COMPLETION CALL - if acompletion is True: - ### ASYNC STREAMING - if stream == True: - return self.async_streaming( - model=model, - messages=messages, - data=data, - api_base=completion_url, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, - ) # type: ignore - else: - ### ASYNC COMPLETION - return self.async_completion( - model=model, - messages=messages, - data=data, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=optional_params, - stream=False, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, - ) # type: ignore - - ### SYNC STREAMING - if stream == True: - response = requests.post( - completion_url, - headers=headers, - data=json.dumps(data), -> stream=optional_params["stream"], - ) -E KeyError: 'stream' + def handle_bedrock_stream(self, chunk): + if "cohere" in self.model or "anthropic" in self.model: + return { + "text": chunk["text"], + "is_finished": chunk["is_finished"], + "finish_reason": chunk["finish_reason"], + } + if hasattr(chunk, "get"): + chunk = chunk.get("chunk") +> chunk_data = json.loads(chunk.get("bytes").decode()) +E AttributeError: 'NoneType' object has no attribute 'get' -../llms/predibase.py:412: KeyError +../utils.py:10648: AttributeError During handling of the above exception, another exception occurred: -sync_mode = True +sync_mode = False, model = 'bedrock/amazon.titan-tg1-large' @pytest.mark.parametrize("sync_mode", [True, False]) + @pytest.mark.parametrize( + "model", + [ + # "bedrock/cohere.command-r-plus-v1:0", + # "anthropic.claude-3-sonnet-20240229-v1:0", + # "anthropic.claude-instant-v1", + # "bedrock/ai21.j2-mid", + # "mistral.mistral-7b-instruct-v0:2", + "bedrock/amazon.titan-tg1-large", + # "meta.llama3-8b-instruct-v1:0", + ], + ) @pytest.mark.asyncio - async def test_completion_predibase_streaming(sync_mode): + async def test_bedrock_httpx_streaming(sync_mode, model): try: litellm.set_verbose = True - if sync_mode: -> response = completion( - model="predibase/llama-3-8b-instruct", - tenant_id="c4768f95", - api_base="https://serving.app.predibase.com", - api_key=os.getenv("PREDIBASE_API_KEY"), - messages=[{"role": "user", "content": "What is the meaning of life?"}], + final_chunk: Optional[litellm.ModelResponse] = None + response: litellm.CustomStreamWrapper = completion( # type: ignore + model=model, + messages=messages, + max_tokens=10, # type: ignore stream=True, ) + complete_response = "" + # Add any assertions here to check the response + has_finish_reason = False + for idx, chunk in enumerate(response): + final_chunk = chunk + chunk, finished = streaming_format_tests(idx, chunk) + if finished: + has_finish_reason = True + break + complete_response += chunk + if has_finish_reason == False: + raise Exception("finish reason not set") + if complete_response.strip() == "": + raise Exception("Empty response received") + else: + response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore + model=model, + messages=messages, + max_tokens=100, # type: ignore + stream=True, + ) + complete_response = "" + # Add any assertions here to check the response + has_finish_reason = False + idx = 0 + final_chunk: Optional[litellm.ModelResponse] = None +> async for chunk in response: -test_streaming.py:317: +test_streaming.py:1094: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -args = () -kwargs = {'api_base': 'https://serving.app.predibase.com', 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , ...} -result = None, start_time = datetime.datetime(2024, 5, 9, 17, 37, 57, 884661) -logging_obj = -call_type = 'completion', model = 'predibase/llama-3-8b-instruct' -k = 'litellm_logging_obj' +self = - @wraps(original_function) - def wrapper(*args, **kwargs): - # DO NOT MOVE THIS. It always needs to run first - # Check if this is an async function. If so only execute the async function - if ( - kwargs.get("acompletion", False) == True - or kwargs.get("aembedding", False) == True - or kwargs.get("aimg_generation", False) == True - or kwargs.get("amoderation", False) == True - or kwargs.get("atext_completion", False) == True - or kwargs.get("atranscription", False) == True - ): - # [OPTIONAL] CHECK MAX RETRIES / REQUEST - if litellm.num_retries_per_request is not None: - # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] - previous_models = kwargs.get("metadata", {}).get( - "previous_models", None - ) - if previous_models is not None: - if litellm.num_retries_per_request <= len(previous_models): - raise Exception(f"Max retries per request hit!") - - # MODEL CALL - result = original_function(*args, **kwargs) - if "stream" in kwargs and kwargs["stream"] == True: - if ( - "complete_response" in kwargs - and kwargs["complete_response"] == True - ): - chunks = [] - for idx, chunk in enumerate(result): - chunks.append(chunk) - return litellm.stream_chunk_builder( - chunks, messages=kwargs.get("messages", None) - ) - else: - return result - return result - - # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print - print_args_passed_to_litellm(original_function, args, kwargs) - start_time = datetime.datetime.now() - result = None - logging_obj = kwargs.get("litellm_logging_obj", None) - - # only set litellm_call_id if its not in kwargs - call_type = original_function.__name__ - if "litellm_call_id" not in kwargs: - kwargs["litellm_call_id"] = str(uuid.uuid4()) + async def __anext__(self): try: - model = args[0] if len(args) > 0 else kwargs["model"] - except: - model = None if ( - call_type != CallTypes.image_generation.value - and call_type != CallTypes.text_completion.value + self.custom_llm_provider == "openai" + or self.custom_llm_provider == "azure" + or self.custom_llm_provider == "custom_openai" + or self.custom_llm_provider == "text-completion-openai" + or self.custom_llm_provider == "azure_text" + or self.custom_llm_provider == "anthropic" + or self.custom_llm_provider == "anthropic_text" + or self.custom_llm_provider == "huggingface" + or self.custom_llm_provider == "ollama" + or self.custom_llm_provider == "ollama_chat" + or self.custom_llm_provider == "vertex_ai" + or self.custom_llm_provider == "sagemaker" + or self.custom_llm_provider == "gemini" + or self.custom_llm_provider == "replicate" + or self.custom_llm_provider == "cached_response" + or self.custom_llm_provider == "predibase" + or self.custom_llm_provider == "bedrock" + or self.custom_llm_provider in litellm.openai_compatible_endpoints ): - raise ValueError("model param not passed in.") - - try: - if logging_obj is None: - logging_obj, kwargs = function_setup( - original_function.__name__, rules_obj, start_time, *args, **kwargs - ) - kwargs["litellm_logging_obj"] = logging_obj - - # CHECK FOR 'os.environ/' in kwargs - for k, v in kwargs.items(): - if v is not None and isinstance(v, str) and v.startswith("os.environ/"): - kwargs[k] = litellm.get_secret(v) - # [OPTIONAL] CHECK BUDGET - if litellm.max_budget: - if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError( - current_cost=litellm._current_cost, - max_budget=litellm.max_budget, - ) - - # [OPTIONAL] CHECK MAX RETRIES / REQUEST - if litellm.num_retries_per_request is not None: - # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] - previous_models = kwargs.get("metadata", {}).get( - "previous_models", None - ) - if previous_models is not None: - if litellm.num_retries_per_request <= len(previous_models): - raise Exception(f"Max retries per request hit!") - - # [OPTIONAL] CHECK CACHE - print_verbose( - f"SYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache')['no-cache']: {kwargs.get('cache', {}).get('no-cache', False)}" - ) - # if caching is false or cache["no-cache"]==True, don't run this - if ( - ( - ( - ( - kwargs.get("caching", None) is None - and litellm.cache is not None - ) - or kwargs.get("caching", False) == True - ) - and kwargs.get("cache", {}).get("no-cache", False) != True - ) - and kwargs.get("aembedding", False) != True - and kwargs.get("atext_completion", False) != True - and kwargs.get("acompletion", False) != True - and kwargs.get("aimg_generation", False) != True - and kwargs.get("atranscription", False) != True - ): # allow users to control returning cached responses from the completion function - # checking cache - print_verbose(f"INSIDE CHECKING CACHE") - if ( - litellm.cache is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): - print_verbose(f"Checking Cache") - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = litellm.cache.get_cache(*args, **kwargs) - if cached_result != None: - if "detail" in cached_result: - # implies an error occurred - pass - else: - call_type = original_function.__name__ - print_verbose( - f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" - ) - if call_type == CallTypes.completion.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=ModelResponse(), - stream=kwargs.get("stream", False), - ) - if kwargs.get("stream", False) == True: - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - elif call_type == CallTypes.embedding.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - response_type="embedding", - ) - - # LOG SUCCESS - cache_hit = True - end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( - model=model, - custom_llm_provider=kwargs.get( - "custom_llm_provider", None - ), - api_base=kwargs.get("api_base", None), - api_key=kwargs.get("api_key", None), - ) - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - logging_obj.update_environment_variables( - model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": False, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get( - "preset_cache_key", None - ), - "stream_response": kwargs.get( - "stream_response", {} - ), - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(cached_result), - additional_args=None, - stream=kwargs.get("stream", False), - ) - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() - return cached_result - - # CHECK MAX TOKENS - if ( - kwargs.get("max_tokens", None) is not None - and model is not None - and litellm.modify_params - == True # user is okay with params being modified - and ( - call_type == CallTypes.acompletion.value - or call_type == CallTypes.completion.value - ) - ): - try: - base_model = model - if kwargs.get("hf_model_name", None) is not None: - base_model = f"huggingface/{kwargs.get('hf_model_name')}" - max_output_tokens = ( - get_max_tokens(model=base_model) or 4096 - ) # assume min context window is 4k tokens - user_max_tokens = kwargs.get("max_tokens") - ## Scenario 1: User limit + prompt > model limit - messages = None - if len(args) > 1: - messages = args[1] - elif kwargs.get("messages", None): - messages = kwargs["messages"] - input_tokens = token_counter(model=base_model, messages=messages) - input_tokens += max( - 0.1 * input_tokens, 10 - ) # give at least a 10 token buffer. token counting can be imprecise. - if input_tokens > max_output_tokens: - pass # allow call to fail normally - elif user_max_tokens + input_tokens > max_output_tokens: - user_max_tokens = max_output_tokens - input_tokens - print_verbose(f"user_max_tokens: {user_max_tokens}") - kwargs["max_tokens"] = int( - round(user_max_tokens) - ) # make sure max tokens is always an int - except Exception as e: - print_verbose(f"Error while checking max token limit: {str(e)}") - # MODEL CALL - result = original_function(*args, **kwargs) - end_time = datetime.datetime.now() - if "stream" in kwargs and kwargs["stream"] == True: - if ( - "complete_response" in kwargs - and kwargs["complete_response"] == True - ): - chunks = [] - for idx, chunk in enumerate(result): - chunks.append(chunk) - return litellm.stream_chunk_builder( - chunks, messages=kwargs.get("messages", None) - ) - else: - return result - elif "acompletion" in kwargs and kwargs["acompletion"] == True: - return result - elif "aembedding" in kwargs and kwargs["aembedding"] == True: - return result - elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True: - return result - elif "atranscription" in kwargs and kwargs["atranscription"] == True: - return result - - ### POST-CALL RULES ### - post_call_processing(original_response=result, model=model or None) - - # [OPTIONAL] ADD TO CACHE - if ( - litellm.cache is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ) and (kwargs.get("cache", {}).get("no-store", False) != True): - litellm.cache.add_cache(result, *args, **kwargs) - - # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated - verbose_logger.info(f"Wrapper: Completed Call, calling success_handler") - threading.Thread( - target=logging_obj.success_handler, args=(result, start_time, end_time) - ).start() - # RETURN RESULT - if hasattr(result, "_hidden_params"): - result._hidden_params["model_id"] = kwargs.get("model_info", {}).get( - "id", None - ) - result._hidden_params["api_base"] = get_api_base( - model=model, - optional_params=getattr(logging_obj, "optional_params", {}), - ) - result._response_ms = ( - end_time - start_time - ).total_seconds() * 1000 # return response latency in ms like openai - return result - except Exception as e: - call_type = original_function.__name__ - if call_type == CallTypes.completion.value: - num_retries = ( - kwargs.get("num_retries", None) or litellm.num_retries or None - ) - litellm.num_retries = ( - None # set retries to None to prevent infinite loops - ) - context_window_fallback_dict = kwargs.get( - "context_window_fallback_dict", {} - ) - - _is_litellm_router_call = "model_group" in kwargs.get( - "metadata", {} - ) # check if call from litellm.router/proxy - if ( - num_retries and not _is_litellm_router_call - ): # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying - if ( - isinstance(e, openai.APIError) - or isinstance(e, openai.Timeout) - or isinstance(e, openai.APIConnectionError) + async for chunk in self.completion_stream: + print_verbose(f"value of async chunk: {chunk}") + if chunk == "None" or chunk is None: + raise Exception + elif ( + self.custom_llm_provider == "gemini" + and hasattr(chunk, "parts") + and len(chunk.parts) == 0 ): - kwargs["num_retries"] = num_retries - return litellm.completion_with_retries(*args, **kwargs) - elif ( - isinstance(e, litellm.exceptions.ContextWindowExceededError) - and context_window_fallback_dict - and model in context_window_fallback_dict - ): - if len(args) > 0: - args[0] = context_window_fallback_dict[model] + continue + # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. + # __anext__ also calls async_success_handler, which does logging + print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}") + + processed_chunk: Optional[ModelResponse] = self.chunk_creator( + chunk=chunk + ) + print_verbose( + f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}" + ) + if processed_chunk is None: + continue + ## LOGGING + threading.Thread( + target=self.logging_obj.success_handler, args=(processed_chunk,) + ).start() # log response + asyncio.create_task( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) + self.response_uptil_now += ( + processed_chunk.choices[0].delta.get("content", "") or "" + ) + self.rules.post_call_rules( + input=self.response_uptil_now, model=self.model + ) + print_verbose(f"final returned processed chunk: {processed_chunk}") + return processed_chunk + raise StopAsyncIteration + else: # temporary patch for non-aiohttp async calls + # example - boto3 bedrock llms + while True: + if isinstance(self.completion_stream, str) or isinstance( + self.completion_stream, bytes + ): + chunk = self.completion_stream else: - kwargs["model"] = context_window_fallback_dict[model] - return original_function(*args, **kwargs) + chunk = next(self.completion_stream) + if chunk is not None and chunk != b"": + print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") + processed_chunk: Optional[ModelResponse] = self.chunk_creator( + chunk=chunk + ) + print_verbose( + f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}" + ) + if processed_chunk is None: + continue + ## LOGGING + threading.Thread( + target=self.logging_obj.success_handler, + args=(processed_chunk,), + ).start() # log processed_chunk + asyncio.create_task( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) + + self.response_uptil_now += ( + processed_chunk.choices[0].delta.get("content", "") or "" + ) + self.rules.post_call_rules( + input=self.response_uptil_now, model=self.model + ) + # RETURN RESULT + return processed_chunk + except StopAsyncIteration: + if self.sent_last_chunk == True: + raise # Re-raise StopIteration + else: + self.sent_last_chunk = True + processed_chunk = self.finish_reason_handler() + ## LOGGING + threading.Thread( + target=self.logging_obj.success_handler, args=(processed_chunk,) + ).start() # log response + asyncio.create_task( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) + return processed_chunk + except StopIteration: + if self.sent_last_chunk == True: + raise StopAsyncIteration + else: + self.sent_last_chunk = True + processed_chunk = self.finish_reason_handler() + ## LOGGING + threading.Thread( + target=self.logging_obj.success_handler, args=(processed_chunk,) + ).start() # log response + asyncio.create_task( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) + return processed_chunk + except Exception as e: traceback_exception = traceback.format_exc() - end_time = datetime.datetime.now() - # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated - if logging_obj: - logging_obj.failure_handler( - e, traceback_exception, start_time, end_time - ) # DO NOT MAKE THREADED - router retry fallback relies on this! - my_thread = threading.Thread( - target=handle_failure, - args=(e, traceback_exception, start_time, end_time, args, kwargs), - ) # don't interrupt execution of main thread - my_thread.start() - if hasattr(e, "message"): - if ( - liteDebuggerClient and liteDebuggerClient.dashboard_url != None - ): # make it easy to get to the debugger logs if you've initialized it - e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}" + # Handle any exceptions that might occur during streaming + asyncio.create_task( + self.logging_obj.async_failure_handler(e, traceback_exception) + ) > raise e -../utils.py:3229: +../utils.py:11630: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -args = () -kwargs = {'api_base': 'https://serving.app.predibase.com', 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , ...} -result = None, start_time = datetime.datetime(2024, 5, 9, 17, 37, 57, 884661) -logging_obj = -call_type = 'completion', model = 'predibase/llama-3-8b-instruct' -k = 'litellm_logging_obj' +self = - @wraps(original_function) - def wrapper(*args, **kwargs): - # DO NOT MOVE THIS. It always needs to run first - # Check if this is an async function. If so only execute the async function - if ( - kwargs.get("acompletion", False) == True - or kwargs.get("aembedding", False) == True - or kwargs.get("aimg_generation", False) == True - or kwargs.get("amoderation", False) == True - or kwargs.get("atext_completion", False) == True - or kwargs.get("atranscription", False) == True - ): - # [OPTIONAL] CHECK MAX RETRIES / REQUEST - if litellm.num_retries_per_request is not None: - # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] - previous_models = kwargs.get("metadata", {}).get( - "previous_models", None - ) - if previous_models is not None: - if litellm.num_retries_per_request <= len(previous_models): - raise Exception(f"Max retries per request hit!") - - # MODEL CALL - result = original_function(*args, **kwargs) - if "stream" in kwargs and kwargs["stream"] == True: - if ( - "complete_response" in kwargs - and kwargs["complete_response"] == True - ): - chunks = [] - for idx, chunk in enumerate(result): - chunks.append(chunk) - return litellm.stream_chunk_builder( - chunks, messages=kwargs.get("messages", None) - ) - else: - return result - return result - - # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print - print_args_passed_to_litellm(original_function, args, kwargs) - start_time = datetime.datetime.now() - result = None - logging_obj = kwargs.get("litellm_logging_obj", None) - - # only set litellm_call_id if its not in kwargs - call_type = original_function.__name__ - if "litellm_call_id" not in kwargs: - kwargs["litellm_call_id"] = str(uuid.uuid4()) + async def __anext__(self): try: - model = args[0] if len(args) > 0 else kwargs["model"] - except: - model = None if ( - call_type != CallTypes.image_generation.value - and call_type != CallTypes.text_completion.value + self.custom_llm_provider == "openai" + or self.custom_llm_provider == "azure" + or self.custom_llm_provider == "custom_openai" + or self.custom_llm_provider == "text-completion-openai" + or self.custom_llm_provider == "azure_text" + or self.custom_llm_provider == "anthropic" + or self.custom_llm_provider == "anthropic_text" + or self.custom_llm_provider == "huggingface" + or self.custom_llm_provider == "ollama" + or self.custom_llm_provider == "ollama_chat" + or self.custom_llm_provider == "vertex_ai" + or self.custom_llm_provider == "sagemaker" + or self.custom_llm_provider == "gemini" + or self.custom_llm_provider == "replicate" + or self.custom_llm_provider == "cached_response" + or self.custom_llm_provider == "predibase" + or self.custom_llm_provider == "bedrock" + or self.custom_llm_provider in litellm.openai_compatible_endpoints ): - raise ValueError("model param not passed in.") + async for chunk in self.completion_stream: + print_verbose(f"value of async chunk: {chunk}") + if chunk == "None" or chunk is None: + raise Exception + elif ( + self.custom_llm_provider == "gemini" + and hasattr(chunk, "parts") + and len(chunk.parts) == 0 + ): + continue + # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. + # __anext__ also calls async_success_handler, which does logging + print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}") - try: - if logging_obj is None: - logging_obj, kwargs = function_setup( - original_function.__name__, rules_obj, start_time, *args, **kwargs - ) - kwargs["litellm_logging_obj"] = logging_obj - - # CHECK FOR 'os.environ/' in kwargs - for k, v in kwargs.items(): - if v is not None and isinstance(v, str) and v.startswith("os.environ/"): - kwargs[k] = litellm.get_secret(v) - # [OPTIONAL] CHECK BUDGET - if litellm.max_budget: - if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError( - current_cost=litellm._current_cost, - max_budget=litellm.max_budget, +> processed_chunk: Optional[ModelResponse] = self.chunk_creator( + chunk=chunk ) + +../utils.py:11528: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +self = +chunk = {'finish_reason': '', 'is_finished': False, 'text': '\nHello, I am an AI model developed by Amazon Titan Foundation Mo...able of understanding and generating human-like text. My development has been focused on continuously improving my pe'} + + def chunk_creator(self, chunk): + model_response = self.model_response_creator() + response_obj = {} + try: + # return this for all models + completion_obj = {"content": ""} + if self.custom_llm_provider and self.custom_llm_provider == "anthropic": + response_obj = self.handle_anthropic_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif ( + self.custom_llm_provider + and self.custom_llm_provider == "anthropic_text" + ): + response_obj = self.handle_anthropic_text_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "clarifai": + response_obj = self.handle_clarifai_completion_chunk(chunk) + completion_obj["content"] = response_obj["text"] + elif self.model == "replicate" or self.custom_llm_provider == "replicate": + response_obj = self.handle_replicate_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "together_ai": + response_obj = self.handle_together_ai_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "huggingface": + response_obj = self.handle_huggingface_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "predibase": + response_obj = self.handle_predibase_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif ( + self.custom_llm_provider and self.custom_llm_provider == "baseten" + ): # baseten doesn't provide streaming + completion_obj["content"] = self.handle_baseten_chunk(chunk) + elif ( + self.custom_llm_provider and self.custom_llm_provider == "ai21" + ): # ai21 doesn't provide streaming + response_obj = self.handle_ai21_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "maritalk": + response_obj = self.handle_maritalk_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "vllm": + completion_obj["content"] = chunk[0].outputs[0].text + elif ( + self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha" + ): # aleph alpha doesn't provide streaming + response_obj = self.handle_aleph_alpha_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "nlp_cloud": + try: + response_obj = self.handle_nlp_cloud_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + except Exception as e: + if self.received_finish_reason: + raise e + else: + if self.sent_first_chunk is False: + raise Exception("An unknown error occurred with the stream") + self.received_finish_reason = "stop" + elif self.custom_llm_provider == "gemini": + if hasattr(chunk, "parts") == True: + try: + if len(chunk.parts) > 0: + completion_obj["content"] = chunk.parts[0].text + if len(chunk.parts) > 0 and hasattr( + chunk.parts[0], "finish_reason" + ): + self.received_finish_reason = chunk.parts[ + 0 + ].finish_reason.name + except: + if chunk.parts[0].finish_reason.name == "SAFETY": + raise Exception( + f"The response was blocked by VertexAI. {str(chunk)}" + ) + else: + completion_obj["content"] = str(chunk) + elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"): + import proto # type: ignore - # [OPTIONAL] CHECK MAX RETRIES / REQUEST - if litellm.num_retries_per_request is not None: - # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] - previous_models = kwargs.get("metadata", {}).get( - "previous_models", None - ) - if previous_models is not None: - if litellm.num_retries_per_request <= len(previous_models): - raise Exception(f"Max retries per request hit!") - - # [OPTIONAL] CHECK CACHE - print_verbose( - f"SYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache')['no-cache']: {kwargs.get('cache', {}).get('no-cache', False)}" - ) - # if caching is false or cache["no-cache"]==True, don't run this - if ( - ( - ( - ( - kwargs.get("caching", None) is None - and litellm.cache is not None + if self.model.startswith("claude-3"): + response_obj = self.handle_vertexai_anthropic_chunk(chunk=chunk) + if response_obj is None: + return + completion_obj["content"] = response_obj["text"] + setattr(model_response, "usage", Usage()) + if response_obj.get("prompt_tokens", None) is not None: + model_response.usage.prompt_tokens = response_obj[ + "prompt_tokens" + ] + if response_obj.get("completion_tokens", None) is not None: + model_response.usage.completion_tokens = response_obj[ + "completion_tokens" + ] + if hasattr(model_response.usage, "prompt_tokens"): + model_response.usage.total_tokens = ( + getattr(model_response.usage, "total_tokens", 0) + + model_response.usage.prompt_tokens ) - or kwargs.get("caching", False) == True - ) - and kwargs.get("cache", {}).get("no-cache", False) != True - ) - and kwargs.get("aembedding", False) != True - and kwargs.get("atext_completion", False) != True - and kwargs.get("acompletion", False) != True - and kwargs.get("aimg_generation", False) != True - and kwargs.get("atranscription", False) != True - ): # allow users to control returning cached responses from the completion function - # checking cache - print_verbose(f"INSIDE CHECKING CACHE") - if ( - litellm.cache is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): - print_verbose(f"Checking Cache") - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = litellm.cache.get_cache(*args, **kwargs) - if cached_result != None: - if "detail" in cached_result: - # implies an error occurred - pass - else: - call_type = original_function.__name__ - print_verbose( - f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" - ) - if call_type == CallTypes.completion.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=ModelResponse(), - stream=kwargs.get("stream", False), + if hasattr(model_response.usage, "completion_tokens"): + model_response.usage.total_tokens = ( + getattr(model_response.usage, "total_tokens", 0) + + model_response.usage.completion_tokens + ) + + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif hasattr(chunk, "candidates") == True: + try: + try: + completion_obj["content"] = chunk.text + except Exception as e: + if "Part has no text." in str(e): + ## check for function calling + function_call = ( + chunk.candidates[0].content.parts[0].function_call ) - if kwargs.get("stream", False) == True: - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, + + args_dict = {} + + # Check if it's a RepeatedComposite instance + for key, val in function_call.args.items(): + if isinstance( + val, + proto.marshal.collections.repeated.RepeatedComposite, + ): + # If so, convert to list + args_dict[key] = [v for v in val] + else: + args_dict[key] = val + + try: + args_str = json.dumps(args_dict) + except Exception as e: + raise e + _delta_obj = litellm.utils.Delta( + content=None, + tool_calls=[ + { + "id": f"call_{str(uuid.uuid4())}", + "function": { + "arguments": args_str, + "name": function_call.name, + }, + "type": "function", + } + ], + ) + _streaming_response = StreamingChoices(delta=_delta_obj) + _model_response = ModelResponse(stream=True) + _model_response.choices = [_streaming_response] + response_obj = {"original_chunk": _model_response} + else: + raise e + if ( + hasattr(chunk.candidates[0], "finish_reason") + and chunk.candidates[0].finish_reason.name + != "FINISH_REASON_UNSPECIFIED" + ): # every non-final chunk in vertex ai has this + self.received_finish_reason = chunk.candidates[ + 0 + ].finish_reason.name + except Exception as e: + if chunk.candidates[0].finish_reason.name == "SAFETY": + raise Exception( + f"The response was blocked by VertexAI. {str(chunk)}" + ) + else: + completion_obj["content"] = str(chunk) + elif self.custom_llm_provider == "cohere": + response_obj = self.handle_cohere_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "cohere_chat": + response_obj = self.handle_cohere_chat_chunk(chunk) + if response_obj is None: + return + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "bedrock": + if self.received_finish_reason is not None: + raise StopIteration + response_obj = self.handle_bedrock_stream(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "sagemaker": + print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") + response_obj = self.handle_sagemaker_stream(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "petals": + if len(self.completion_stream) == 0: + if self.received_finish_reason is not None: + raise StopIteration + else: + self.received_finish_reason = "stop" + chunk_size = 30 + new_chunk = self.completion_stream[:chunk_size] + completion_obj["content"] = new_chunk + self.completion_stream = self.completion_stream[chunk_size:] + time.sleep(0.05) + elif self.custom_llm_provider == "palm": + # fake streaming + response_obj = {} + if len(self.completion_stream) == 0: + if self.received_finish_reason is not None: + raise StopIteration + else: + self.received_finish_reason = "stop" + chunk_size = 30 + new_chunk = self.completion_stream[:chunk_size] + completion_obj["content"] = new_chunk + self.completion_stream = self.completion_stream[chunk_size:] + time.sleep(0.05) + elif self.custom_llm_provider == "ollama": + response_obj = self.handle_ollama_stream(chunk) + completion_obj["content"] = response_obj["text"] + print_verbose(f"completion obj content: {completion_obj['content']}") + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "ollama_chat": + response_obj = self.handle_ollama_chat_stream(chunk) + completion_obj["content"] = response_obj["text"] + print_verbose(f"completion obj content: {completion_obj['content']}") + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "cloudflare": + response_obj = self.handle_cloudlfare_stream(chunk) + completion_obj["content"] = response_obj["text"] + print_verbose(f"completion obj content: {completion_obj['content']}") + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "watsonx": + response_obj = self.handle_watsonx_stream(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "text-completion-openai": + response_obj = self.handle_openai_text_completion_chunk(chunk) + completion_obj["content"] = response_obj["text"] + print_verbose(f"completion obj content: {completion_obj['content']}") + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + if ( + self.stream_options + and self.stream_options.get("include_usage", False) == True + ): + model_response.usage = response_obj["usage"] + elif self.custom_llm_provider == "azure_text": + response_obj = self.handle_azure_text_completion_chunk(chunk) + completion_obj["content"] = response_obj["text"] + print_verbose(f"completion obj content: {completion_obj['content']}") + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider == "cached_response": + response_obj = { + "text": chunk.choices[0].delta.content, + "is_finished": True, + "finish_reason": chunk.choices[0].finish_reason, + "original_chunk": chunk, + } + + completion_obj["content"] = response_obj["text"] + print_verbose(f"completion obj content: {completion_obj['content']}") + if hasattr(chunk, "id"): + model_response.id = chunk.id + self.response_id = chunk.id + if hasattr(chunk, "system_fingerprint"): + self.system_fingerprint = chunk.system_fingerprint + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + else: # openai / azure chat model + if self.custom_llm_provider == "azure": + if hasattr(chunk, "model"): + # for azure, we need to pass the model from the orignal chunk + self.model = chunk.model + response_obj = self.handle_openai_chat_completion_chunk(chunk) + if response_obj == None: + return + completion_obj["content"] = response_obj["text"] + print_verbose(f"completion obj content: {completion_obj['content']}") + if response_obj["is_finished"]: + if response_obj["finish_reason"] == "error": + raise Exception( + "Mistral API raised a streaming error - finish_reason: error, no content string given." + ) + self.received_finish_reason = response_obj["finish_reason"] + if response_obj.get("original_chunk", None) is not None: + if hasattr(response_obj["original_chunk"], "id"): + model_response.id = response_obj["original_chunk"].id + self.response_id = model_response.id + if hasattr(response_obj["original_chunk"], "system_fingerprint"): + model_response.system_fingerprint = response_obj[ + "original_chunk" + ].system_fingerprint + self.system_fingerprint = response_obj[ + "original_chunk" + ].system_fingerprint + if response_obj["logprobs"] is not None: + model_response.choices[0].logprobs = response_obj["logprobs"] + + if ( + self.stream_options is not None + and self.stream_options["include_usage"] == True + ): + model_response.usage = response_obj["usage"] + + model_response.model = self.model + print_verbose( + f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}" + ) + ## FUNCTION CALL PARSING + if ( + response_obj is not None + and response_obj.get("original_chunk", None) is not None + ): # function / tool calling branch - only set for openai/azure compatible endpoints + # enter this branch when no content has been passed in response + original_chunk = response_obj.get("original_chunk", None) + model_response.id = original_chunk.id + self.response_id = original_chunk.id + if len(original_chunk.choices) > 0: + if ( + original_chunk.choices[0].delta.function_call is not None + or original_chunk.choices[0].delta.tool_calls is not None + ): + try: + delta = original_chunk.choices[0].delta + model_response.system_fingerprint = ( + original_chunk.system_fingerprint + ) + ## AZURE - check if arguments is not None + if ( + original_chunk.choices[0].delta.function_call + is not None + ): + if ( + getattr( + original_chunk.choices[0].delta.function_call, + "arguments", ) - elif call_type == CallTypes.embedding.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - response_type="embedding", + is None + ): + original_chunk.choices[ + 0 + ].delta.function_call.arguments = "" + elif original_chunk.choices[0].delta.tool_calls is not None: + if isinstance( + original_chunk.choices[0].delta.tool_calls, list + ): + for t in original_chunk.choices[0].delta.tool_calls: + if hasattr(t, "functions") and hasattr( + t.functions, "arguments" + ): + if ( + getattr( + t.function, + "arguments", + ) + is None + ): + t.function.arguments = "" + _json_delta = delta.model_dump() + print_verbose(f"_json_delta: {_json_delta}") + if "role" not in _json_delta or _json_delta["role"] is None: + _json_delta["role"] = ( + "assistant" # mistral's api returns role as None ) - - # LOG SUCCESS - cache_hit = True - end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( - model=model, - custom_llm_provider=kwargs.get( - "custom_llm_provider", None - ), - api_base=kwargs.get("api_base", None), - api_key=kwargs.get("api_key", None), - ) + if "tool_calls" in _json_delta and isinstance( + _json_delta["tool_calls"], list + ): + for tool in _json_delta["tool_calls"]: + if ( + isinstance(tool, dict) + and "function" in tool + and isinstance(tool["function"], dict) + and ("type" not in tool or tool["type"] is None) + ): + # if function returned but type set to None - mistral's api returns type: None + tool["type"] = "function" + model_response.choices[0].delta = Delta(**_json_delta) + except Exception as e: + traceback.print_exc() + model_response.choices[0].delta = Delta() + else: + try: + delta = dict(original_chunk.choices[0].delta) + print_verbose(f"original delta: {delta}") + model_response.choices[0].delta = Delta(**delta) print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + f"new delta: {model_response.choices[0].delta}" ) - logging_obj.update_environment_variables( - model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": False, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get( - "preset_cache_key", None - ), - "stream_response": kwargs.get( - "stream_response", {} - ), - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(cached_result), - additional_args=None, - stream=kwargs.get("stream", False), - ) - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() - return cached_result - - # CHECK MAX TOKENS - if ( - kwargs.get("max_tokens", None) is not None - and model is not None - and litellm.modify_params - == True # user is okay with params being modified - and ( - call_type == CallTypes.acompletion.value - or call_type == CallTypes.completion.value - ) - ): - try: - base_model = model - if kwargs.get("hf_model_name", None) is not None: - base_model = f"huggingface/{kwargs.get('hf_model_name')}" - max_output_tokens = ( - get_max_tokens(model=base_model) or 4096 - ) # assume min context window is 4k tokens - user_max_tokens = kwargs.get("max_tokens") - ## Scenario 1: User limit + prompt > model limit - messages = None - if len(args) > 1: - messages = args[1] - elif kwargs.get("messages", None): - messages = kwargs["messages"] - input_tokens = token_counter(model=base_model, messages=messages) - input_tokens += max( - 0.1 * input_tokens, 10 - ) # give at least a 10 token buffer. token counting can be imprecise. - if input_tokens > max_output_tokens: - pass # allow call to fail normally - elif user_max_tokens + input_tokens > max_output_tokens: - user_max_tokens = max_output_tokens - input_tokens - print_verbose(f"user_max_tokens: {user_max_tokens}") - kwargs["max_tokens"] = int( - round(user_max_tokens) - ) # make sure max tokens is always an int - except Exception as e: - print_verbose(f"Error while checking max token limit: {str(e)}") - # MODEL CALL -> result = original_function(*args, **kwargs) - -../utils.py:3123: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -model = 'llama-3-8b-instruct' -messages = [{'content': 'What is the meaning of life?', 'role': 'user'}] -timeout = 600.0, temperature = None, top_p = None, n = None, stream = True -stream_options = None, stop = None, max_tokens = None, presence_penalty = None -frequency_penalty = None, logit_bias = None, user = None, response_format = None -seed = None, tools = None, tool_choice = None, logprobs = None -top_logprobs = None, deployment_id = None, extra_headers = None -functions = None, function_call = None, base_url = None, api_version = None -api_key = 'pb_Qg9YbQo7UqqHdu0ozxN_aw', model_list = None -kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , 'tenant_id': 'c4768f95'} -args = {'acompletion': False, 'api_base': None, 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'api_version': None, ...} -api_base = None, mock_response = None, force_timeout = 600, logger_fn = None -verbose = False, custom_llm_provider = 'predibase' - - @client - def completion( - model: str, - # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create - messages: List = [], - timeout: Optional[Union[float, str, httpx.Timeout]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - n: Optional[int] = None, - stream: Optional[bool] = None, - stream_options: Optional[dict] = None, - stop=None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[dict] = None, - user: Optional[str] = None, - # openai v1.0+ new params - response_format: Optional[dict] = None, - seed: Optional[int] = None, - tools: Optional[List] = None, - tool_choice: Optional[str] = None, - logprobs: Optional[bool] = None, - top_logprobs: Optional[int] = None, - deployment_id=None, - extra_headers: Optional[dict] = None, - # soon to be deprecated params by OpenAI - functions: Optional[List] = None, - function_call: Optional[str] = None, - # set api_base, api_version, api_key - base_url: Optional[str] = None, - api_version: Optional[str] = None, - api_key: Optional[str] = None, - model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. - # Optional liteLLM function params - **kwargs, - ) -> Union[ModelResponse, CustomStreamWrapper]: - """ - Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) - Parameters: - model (str): The name of the language model to use for text completion. see all supported LLMs: https://docs.litellm.ai/docs/providers/ - messages (List): A list of message objects representing the conversation context (default is an empty list). - - OPTIONAL PARAMS - functions (List, optional): A list of functions to apply to the conversation messages (default is an empty list). - function_call (str, optional): The name of the function to call within the conversation (default is an empty string). - temperature (float, optional): The temperature parameter for controlling the randomness of the output (default is 1.0). - top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). - n (int, optional): The number of completions to generate (default is 1). - stream (bool, optional): If True, return a streaming response (default is False). - stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true. - stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. - max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). - presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. - frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. - logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. - user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. - logprobs (bool, optional): Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message - top_logprobs (int, optional): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. - metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. - api_base (str, optional): Base URL for the API (default is None). - api_version (str, optional): API version (default is None). - api_key (str, optional): API key (default is None). - model_list (list, optional): List of api base, version, keys - extra_headers (dict, optional): Additional headers to include in the request. - - LITELLM Specific Params - mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None). - custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock" - max_retries (int, optional): The number of retries to attempt (default is 0). - Returns: - ModelResponse: A response object containing the generated completion and associated metadata. - - Note: - - This function is used to perform completions() using the specified language model. - - It supports various optional parameters for customizing the completion behavior. - - If 'mock_response' is provided, a mock completion response is returned for testing or debugging. - """ - ######### unpacking kwargs ##################### - args = locals() - api_base = kwargs.get("api_base", None) - mock_response = kwargs.get("mock_response", None) - force_timeout = kwargs.get("force_timeout", 600) ## deprecated - logger_fn = kwargs.get("logger_fn", None) - verbose = kwargs.get("verbose", False) - custom_llm_provider = kwargs.get("custom_llm_provider", None) - litellm_logging_obj = kwargs.get("litellm_logging_obj", None) - id = kwargs.get("id", None) - metadata = kwargs.get("metadata", None) - model_info = kwargs.get("model_info", None) - proxy_server_request = kwargs.get("proxy_server_request", None) - fallbacks = kwargs.get("fallbacks", None) - headers = kwargs.get("headers", None) - num_retries = kwargs.get("num_retries", None) ## deprecated - max_retries = kwargs.get("max_retries", None) - context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None) - organization = kwargs.get("organization", None) - ### CUSTOM MODEL COST ### - input_cost_per_token = kwargs.get("input_cost_per_token", None) - output_cost_per_token = kwargs.get("output_cost_per_token", None) - input_cost_per_second = kwargs.get("input_cost_per_second", None) - output_cost_per_second = kwargs.get("output_cost_per_second", None) - ### CUSTOM PROMPT TEMPLATE ### - initial_prompt_value = kwargs.get("initial_prompt_value", None) - roles = kwargs.get("roles", None) - final_prompt_value = kwargs.get("final_prompt_value", None) - bos_token = kwargs.get("bos_token", None) - eos_token = kwargs.get("eos_token", None) - preset_cache_key = kwargs.get("preset_cache_key", None) - hf_model_name = kwargs.get("hf_model_name", None) - supports_system_message = kwargs.get("supports_system_message", None) - ### TEXT COMPLETION CALLS ### - text_completion = kwargs.get("text_completion", False) - atext_completion = kwargs.get("atext_completion", False) - ### ASYNC CALLS ### - acompletion = kwargs.get("acompletion", False) - client = kwargs.get("client", None) - ### Admin Controls ### - no_log = kwargs.get("no-log", False) - ######## end of unpacking kwargs ########### - openai_params = [ - "functions", - "function_call", - "temperature", - "temperature", - "top_p", - "n", - "stream", - "stream_options", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "request_timeout", - "api_base", - "api_version", - "api_key", - "deployment_id", - "organization", - "base_url", - "default_headers", - "timeout", - "response_format", - "seed", - "tools", - "tool_choice", - "max_retries", - "logprobs", - "top_logprobs", - "extra_headers", - ] - litellm_params = [ - "metadata", - "acompletion", - "atext_completion", - "text_completion", - "caching", - "mock_response", - "api_key", - "api_version", - "api_base", - "force_timeout", - "logger_fn", - "verbose", - "custom_llm_provider", - "litellm_logging_obj", - "litellm_call_id", - "use_client", - "id", - "fallbacks", - "azure", - "headers", - "model_list", - "num_retries", - "context_window_fallback_dict", - "retry_policy", - "roles", - "final_prompt_value", - "bos_token", - "eos_token", - "request_timeout", - "complete_response", - "self", - "client", - "rpm", - "tpm", - "max_parallel_requests", - "input_cost_per_token", - "output_cost_per_token", - "input_cost_per_second", - "output_cost_per_second", - "hf_model_name", - "model_info", - "proxy_server_request", - "preset_cache_key", - "caching_groups", - "ttl", - "cache", - "no-log", - "base_model", - "stream_timeout", - "supports_system_message", - "region_name", - "allowed_model_region", - ] - default_params = openai_params + litellm_params - non_default_params = { - k: v for k, v in kwargs.items() if k not in default_params - } # model-specific params - pass them straight to the model/provider - - ### TIMEOUT LOGIC ### - timeout = timeout or kwargs.get("request_timeout", 600) or 600 - # set timeout for 10 minutes by default - - if ( - timeout is not None - and isinstance(timeout, httpx.Timeout) - and supports_httpx_timeout(custom_llm_provider) == False - ): - read_timeout = timeout.read or 600 - timeout = read_timeout # default 10 min timeout - elif timeout is not None and not isinstance(timeout, httpx.Timeout): - timeout = float(timeout) # type: ignore - - try: - if base_url is not None: - api_base = base_url - if max_retries is not None: # openai allows openai.OpenAI(max_retries=3) - num_retries = max_retries - logging = litellm_logging_obj - fallbacks = fallbacks or litellm.model_fallbacks - if fallbacks is not None: - return completion_with_fallbacks(**args) - if model_list is not None: - deployments = [ - m["litellm_params"] for m in model_list if m["model_name"] == model - ] - return batch_completion_models(deployments=deployments, **args) - if litellm.model_alias_map and model in litellm.model_alias_map: - model = litellm.model_alias_map[ - model - ] # update the model to the actual value if an alias has been passed in - model_response = ModelResponse() - setattr(model_response, "usage", litellm.Usage()) - if ( - kwargs.get("azure", False) == True - ): # don't remove flag check, to remain backwards compatible for repos like Codium - custom_llm_provider = "azure" - if deployment_id != None: # azure llms - model = deployment_id - custom_llm_provider = "azure" - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - api_key=api_key, - ) - if model_response is not None and hasattr(model_response, "_hidden_params"): - model_response._hidden_params["custom_llm_provider"] = custom_llm_provider - model_response._hidden_params["region_name"] = kwargs.get( - "aws_region_name", None - ) # support region-based pricing for bedrock - - ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### - if input_cost_per_token is not None and output_cost_per_token is not None: - print_verbose(f"Registering model={model} in model cost map") - litellm.register_model( - { - f"{custom_llm_provider}/{model}": { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider, - }, - model: { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider, - }, - } - ) - elif ( - input_cost_per_second is not None - ): # time based pricing just needs cost in place - output_cost_per_second = output_cost_per_second - litellm.register_model( - { - f"{custom_llm_provider}/{model}": { - "input_cost_per_second": input_cost_per_second, - "output_cost_per_second": output_cost_per_second, - "litellm_provider": custom_llm_provider, - }, - model: { - "input_cost_per_second": input_cost_per_second, - "output_cost_per_second": output_cost_per_second, - "litellm_provider": custom_llm_provider, - }, - } - ) - ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### - custom_prompt_dict = {} # type: ignore - if ( - initial_prompt_value - or roles - or final_prompt_value - or bos_token - or eos_token - ): - custom_prompt_dict = {model: {}} - if initial_prompt_value: - custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value - if roles: - custom_prompt_dict[model]["roles"] = roles - if final_prompt_value: - custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value - if bos_token: - custom_prompt_dict[model]["bos_token"] = bos_token - if eos_token: - custom_prompt_dict[model]["eos_token"] = eos_token - - if ( - supports_system_message is not None - and isinstance(supports_system_message, bool) - and supports_system_message == False - ): - messages = map_system_message_pt(messages=messages) - model_api_key = get_api_key( - llm_provider=custom_llm_provider, dynamic_api_key=api_key - ) # get the api key from the environment if required for the model - - if dynamic_api_key is not None: - api_key = dynamic_api_key - # check if user passed in any of the OpenAI optional params - optional_params = get_optional_params( - functions=functions, - function_call=function_call, - temperature=temperature, - top_p=top_p, - n=n, - stream=stream, - stream_options=stream_options, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - user=user, - # params to identify the model - model=model, - custom_llm_provider=custom_llm_provider, - response_format=response_format, - seed=seed, - tools=tools, - tool_choice=tool_choice, - max_retries=max_retries, - logprobs=logprobs, - top_logprobs=top_logprobs, - extra_headers=extra_headers, - **non_default_params, - ) - - if litellm.add_function_to_prompt and optional_params.get( - "functions_unsupported_model", None - ): # if user opts to add it to prompt, when API doesn't support function calling - functions_unsupported_model = optional_params.pop( - "functions_unsupported_model" - ) - messages = function_call_prompt( - messages=messages, functions=functions_unsupported_model - ) - - # For logging - save the values of the litellm-specific params passed in - litellm_params = get_litellm_params( - acompletion=acompletion, - api_key=api_key, - force_timeout=force_timeout, - logger_fn=logger_fn, - verbose=verbose, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - litellm_call_id=kwargs.get("litellm_call_id", None), - model_alias_map=litellm.model_alias_map, - completion_call_id=id, - metadata=metadata, - model_info=model_info, - proxy_server_request=proxy_server_request, - preset_cache_key=preset_cache_key, - no_log=no_log, - ) - logging.update_environment_variables( - model=model, - user=user, - optional_params=optional_params, - litellm_params=litellm_params, - ) - if mock_response: - return mock_completion( - model, - messages, - stream=stream, - mock_response=mock_response, - logging=logging, - acompletion=acompletion, - ) - if custom_llm_provider == "azure": - # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" - - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") - - api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") - ) - - azure_ad_token = optional_params.get("extra_body", {}).pop( - "azure_ad_token", None - ) or get_secret("AZURE_AD_TOKEN") - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.AzureOpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - response = azure_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - api_version=api_version, - api_type=api_type, - azure_ad_token=azure_ad_token, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, # type: ignore - client=client, # pass AsyncAzureOpenAI, AzureOpenAI client - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={ - "headers": headers, - "api_version": api_version, - "api_base": api_base, - }, - ) - elif custom_llm_provider == "azure_text": - # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" - - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") - - api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") - ) - - azure_ad_token = optional_params.get("extra_body", {}).pop( - "azure_ad_token", None - ) or get_secret("AZURE_AD_TOKEN") - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.AzureOpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - response = azure_text_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - api_version=api_version, - api_type=api_type, - azure_ad_token=azure_ad_token, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, - client=client, # pass AsyncAzureOpenAI, AzureOpenAI client - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={ - "headers": headers, - "api_version": api_version, - "api_base": api_base, - }, - ) - elif ( - model in litellm.open_ai_chat_completion_models - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "deepinfra" - or custom_llm_provider == "perplexity" - or custom_llm_provider == "groq" - or custom_llm_provider == "deepseek" - or custom_llm_provider == "anyscale" - or custom_llm_provider == "mistral" - or custom_llm_provider == "openai" - or custom_llm_provider == "together_ai" - or custom_llm_provider in litellm.openai_compatible_providers - or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo - ): # allow user to make an openai call with a custom base - # note: if a user sets a custom base - we should ensure this works - # allow for the setting of dynamic and stateful api-bases - api_base = ( - api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there - or litellm.api_base - or get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) - openai.organization = ( - organization - or litellm.organization - or get_secret("OPENAI_ORGANIZATION") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 - ) - # set API KEY - api_key = ( - api_key - or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.OpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - try: - response = openai_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - model_response=model_response, - print_verbose=print_verbose, - api_key=api_key, - api_base=api_base, - acompletion=acompletion, - logging_obj=logging, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - timeout=timeout, # type: ignore - custom_prompt_dict=custom_prompt_dict, - client=client, # pass AsyncOpenAI, OpenAI client - organization=organization, - custom_llm_provider=custom_llm_provider, - ) - except Exception as e: - ## LOGGING - log the original exception returned - logging.post_call( - input=messages, - api_key=api_key, - original_response=str(e), - additional_args={"headers": headers}, - ) - raise e - - if optional_params.get("stream", False): - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={"headers": headers}, - ) - elif ( - custom_llm_provider == "text-completion-openai" - or "ft:babbage-002" in model - or "ft:davinci-002" in model # support for finetuned completion models - ): - openai.api_type = "openai" - - api_base = ( - api_base - or litellm.api_base - or get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) - - openai.api_version = None - # set API KEY - - api_key = ( - api_key - or litellm.api_key - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.OpenAITextCompletionConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - if litellm.organization: - openai.organization = litellm.organization - - if ( - len(messages) > 0 - and "content" in messages[0] - and type(messages[0]["content"]) == list - ): - # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content'] - # https://platform.openai.com/docs/api-reference/completions/create - prompt = messages[0]["content"] + except Exception as e: + model_response.choices[0].delta = Delta() else: - prompt = " ".join([message["content"] for message in messages]) # type: ignore + if ( + self.stream_options is not None + and self.stream_options["include_usage"] == True + ): + return model_response + return + print_verbose( + f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}" + ) + print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}") - ## COMPLETION CALL - _response = openai_text_completions.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - api_key=api_key, - api_base=api_base, - acompletion=acompletion, - client=client, # pass AsyncOpenAI, OpenAI client - logging_obj=logging, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - timeout=timeout, # type: ignore - ) - - if ( - optional_params.get("stream", False) == False - and acompletion == False - and text_completion == False - ): - # convert to chat completion response - _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object( - response_object=_response, model_response_object=model_response - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=_response, - additional_args={"headers": headers}, - ) - response = _response - elif ( - "replicate" in model - or custom_llm_provider == "replicate" - or model in litellm.replicate_models + ## RETURN ARG + if ( + "content" in completion_obj + and isinstance(completion_obj["content"], str) + and len(completion_obj["content"]) == 0 + and hasattr(model_response, "usage") + and hasattr(model_response.usage, "prompt_tokens") ): - # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN") - replicate_key = None - replicate_key = ( - api_key - or litellm.replicate_key - or litellm.api_key - or get_secret("REPLICATE_API_KEY") - or get_secret("REPLICATE_API_TOKEN") + if self.sent_first_chunk == False: + completion_obj["role"] = "assistant" + self.sent_first_chunk = True + model_response.choices[0].delta = Delta(**completion_obj) + print_verbose(f"returning model_response: {model_response}") + return model_response + elif ( + "content" in completion_obj + and isinstance(completion_obj["content"], str) + and len(completion_obj["content"]) > 0 + ): # cannot set content of an OpenAI Object to be an empty string + hold, model_response_str = self.check_special_tokens( + chunk=completion_obj["content"], + finish_reason=model_response.choices[0].finish_reason, + ) # filter out bos/eos tokens from openai-compatible hf endpoints + print_verbose( + f"hold - {hold}, model_response_str - {model_response_str}" ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("REPLICATE_API_BASE") - or "https://api.replicate.com/v1" - ) - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - model_response = replicate.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=replicate_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - ) - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=replicate_key, - original_response=model_response, - ) - - response = model_response - - elif custom_llm_provider == "anthropic": - api_key = ( - api_key - or litellm.anthropic_key - or litellm.api_key - or os.environ.get("ANTHROPIC_API_KEY") - ) - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - if (model == "claude-2") or (model == "claude-instant-1"): - # call anthropic /completion, only use this route for claude-2, claude-instant-1 - api_base = ( - api_base - or litellm.api_base - or get_secret("ANTHROPIC_API_BASE") - or "https://api.anthropic.com/v1/complete" - ) - response = anthropic_text_completions.completion( - model=model, - messages=messages, - api_base=api_base, - acompletion=acompletion, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=api_key, - logging_obj=logging, - headers=headers, - ) - else: - # call /messages - # default route for all anthropic models - api_base = ( - api_base - or litellm.api_base - or get_secret("ANTHROPIC_API_BASE") - or "https://api.anthropic.com/v1/messages" - ) - response = anthropic_chat_completions.completion( - model=model, - messages=messages, - api_base=api_base, - acompletion=acompletion, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=api_key, - logging_obj=logging, - headers=headers, - ) - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - ) - response = response - elif custom_llm_provider == "nlp_cloud": - nlp_cloud_key = ( - api_key - or litellm.nlp_cloud_key - or get_secret("NLP_CLOUD_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("NLP_CLOUD_API_BASE") - or "https://api.nlpcloud.io/v1/gpu/" - ) - - response = nlp_cloud.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=nlp_cloud_key, - logging_obj=logging, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="nlp_cloud", - logging_obj=logging, - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - ) - - response = response - elif custom_llm_provider == "aleph_alpha": - aleph_alpha_key = ( - api_key - or litellm.aleph_alpha_key - or get_secret("ALEPH_ALPHA_API_KEY") - or get_secret("ALEPHALPHA_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("ALEPH_ALPHA_API_BASE") - or "https://api.aleph-alpha.com/complete" - ) - - model_response = aleph_alpha.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - default_max_tokens_to_sample=litellm.max_tokens, - api_key=aleph_alpha_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="aleph_alpha", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "cohere": - cohere_key = ( - api_key - or litellm.cohere_key - or get_secret("COHERE_API_KEY") - or get_secret("CO_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("COHERE_API_BASE") - or "https://api.cohere.ai/v1/generate" - ) - - model_response = cohere.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=cohere_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="cohere", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "cohere_chat": - cohere_key = ( - api_key - or litellm.cohere_key - or get_secret("COHERE_API_KEY") - or get_secret("CO_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("COHERE_API_BASE") - or "https://api.cohere.ai/v1/chat" - ) - - model_response = cohere_chat.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=cohere_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="cohere_chat", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "maritalk": - maritalk_key = ( - api_key - or litellm.maritalk_key - or get_secret("MARITALK_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("MARITALK_API_BASE") - or "https://chat.maritaca.ai/api/chat/inference" - ) - - model_response = maritalk.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=maritalk_key, - logging_obj=logging, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="maritalk", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "huggingface": - custom_llm_provider = "huggingface" - huggingface_key = ( - api_key - or litellm.huggingface_key - or os.environ.get("HF_TOKEN") - or os.environ.get("HUGGINGFACE_API_KEY") - or litellm.api_key - ) - hf_headers = headers or litellm.headers - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - model_response = huggingface.completion( - model=model, - messages=messages, - api_base=api_base, # type: ignore - headers=hf_headers, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=huggingface_key, - acompletion=acompletion, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - timeout=timeout, # type: ignore - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion is False - ): - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="huggingface", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "oobabooga": - custom_llm_provider = "oobabooga" - model_response = oobabooga.completion( - model=model, - messages=messages, - model_response=model_response, - api_base=api_base, # type: ignore - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - api_key=None, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - ) - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="oobabooga", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "openrouter": - api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1" - - api_key = ( - api_key - or litellm.api_key - or litellm.openrouter_key - or get_secret("OPENROUTER_API_KEY") - or get_secret("OR_API_KEY") - ) - - openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai" - - openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" - - headers = ( - headers - or litellm.headers - or { - "HTTP-Referer": openrouter_site_url, - "X-Title": openrouter_app_name, - } - ) - - ## Load Config - config = openrouter.OpenrouterConfig.get_config() - for k, v in config.items(): - if k == "extra_body": - # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models - if "extra_body" in optional_params: - optional_params[k].update(v) + if hold is False: + ## check if openai/azure chunk + original_chunk = response_obj.get("original_chunk", None) + if original_chunk: + model_response.id = original_chunk.id + self.response_id = original_chunk.id + if len(original_chunk.choices) > 0: + choices = [] + for idx, choice in enumerate(original_chunk.choices): + try: + if isinstance(choice, BaseModel): + try: + choice_json = choice.model_dump() + except Exception as e: + choice_json = choice.dict() + choice_json.pop( + "finish_reason", None + ) # for mistral etc. which return a value in their last chunk (not-openai compatible). + print_verbose(f"choice_json: {choice_json}") + choices.append(StreamingChoices(**choice_json)) + except Exception as e: + choices.append(StreamingChoices()) + print_verbose(f"choices in streaming: {choices}") + model_response.choices = choices else: - optional_params[k] = v - elif k not in optional_params: - optional_params[k] = v - - data = {"model": model, "messages": messages, **optional_params} - - ## COMPLETION CALL - response = openai_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, # type: ignore - ) - ## LOGGING - logging.post_call( - input=messages, api_key=openai.api_key, original_response=response - ) - elif ( - custom_llm_provider == "together_ai" - or ("togethercomputer" in model) - or (model in litellm.together_ai_models) - ): - """ - Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility - """ - custom_llm_provider = "together_ai" - together_ai_key = ( - api_key - or litellm.togetherai_api_key - or get_secret("TOGETHER_AI_TOKEN") - or get_secret("TOGETHERAI_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("TOGETHERAI_API_BASE") - or "https://api.together.xyz/inference" - ) - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - model_response = together_ai.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=together_ai_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - ) - if ( - "stream_tokens" in optional_params - and optional_params["stream_tokens"] == True - ): - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="together_ai", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "palm": - palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key - - # palm does not support streaming as yet :( - model_response = palm.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=palm_api_key, - logging_obj=logging, - ) - # fake palm streaming - if "stream" in optional_params and optional_params["stream"] == True: - # fake streaming for palm - resp_string = model_response["choices"][0]["message"]["content"] - response = CustomStreamWrapper( - resp_string, model, custom_llm_provider="palm", logging_obj=logging - ) - return response - response = model_response - elif custom_llm_provider == "gemini": - gemini_api_key = ( - api_key - or get_secret("GEMINI_API_KEY") - or get_secret("PALM_API_KEY") # older palm api key should also work - or litellm.api_key - ) - - # palm does not support streaming as yet :( - model_response = gemini.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=gemini_api_key, - logging_obj=logging, - acompletion=acompletion, - custom_prompt_dict=custom_prompt_dict, - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion == False - ): - response = CustomStreamWrapper( - iter(model_response), - model, - custom_llm_provider="gemini", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "vertex_ai": - vertex_ai_project = ( - optional_params.pop("vertex_project", None) - or optional_params.pop("vertex_ai_project", None) - or litellm.vertex_project - or get_secret("VERTEXAI_PROJECT") - ) - vertex_ai_location = ( - optional_params.pop("vertex_location", None) - or optional_params.pop("vertex_ai_location", None) - or litellm.vertex_location - or get_secret("VERTEXAI_LOCATION") - ) - vertex_credentials = ( - optional_params.pop("vertex_credentials", None) - or optional_params.pop("vertex_ai_credentials", None) - or get_secret("VERTEXAI_CREDENTIALS") - ) - new_params = deepcopy(optional_params) - if "claude-3" in model: - model_response = vertex_ai_anthropic.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - ) - else: - model_response = vertex_ai.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - ) - - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion == False - ): - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="vertex_ai", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "predibase": - tenant_id = ( - optional_params.pop("tenant_id", None) - or optional_params.pop("predibase_tenant_id", None) - or litellm.predibase_tenant_id - or get_secret("PREDIBASE_TENANT_ID") - ) - - api_base = ( - optional_params.pop("api_base", None) - or optional_params.pop("base_url", None) - or litellm.api_base - or get_secret("PREDIBASE_API_BASE") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.predibase_key - or get_secret("PREDIBASE_API_KEY") - ) - - model_response = predibase_chat_completions.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - acompletion=acompletion, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - api_key=api_key, - tenant_id=tenant_id, - ) - - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion == False - ): - return response - response = model_response - elif custom_llm_provider == "ai21": - custom_llm_provider = "ai21" - ai21_key = ( - api_key - or litellm.ai21_key - or os.environ.get("AI21_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("AI21_API_BASE") - or "https://api.ai21.com/studio/v1/" - ) - - model_response = ai21.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=ai21_key, - logging_obj=logging, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="ai21", - logging_obj=logging, - ) - return response - - ## RESPONSE OBJECT - response = model_response - elif custom_llm_provider == "sagemaker": - # boto3 reads keys from .env - model_response = sagemaker.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - custom_prompt_dict=custom_prompt_dict, - hf_model_name=hf_model_name, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - acompletion=acompletion, - ) - if ( - "stream" in optional_params and optional_params["stream"] == True - ): ## [BETA] - print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER") - from .llms.sagemaker import TokenIterator - - tokenIterator = TokenIterator(model_response, acompletion=acompletion) - response = CustomStreamWrapper( - completion_stream=tokenIterator, - model=model, - custom_llm_provider="sagemaker", - logging_obj=logging, - ) - ## LOGGING - logging.post_call( - input=messages, - api_key=None, - original_response=response, - ) - return response - - ## RESPONSE OBJECT - response = model_response - elif custom_llm_provider == "bedrock": - # boto3 reads keys from .env - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - response = bedrock.completion( - model=model, - messages=messages, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - extra_headers=extra_headers, - timeout=timeout, - ) - - if ( - "stream" in optional_params - and optional_params["stream"] == True - and not isinstance(response, CustomStreamWrapper) - ): - # don't try to access stream object, - if "ai21" in model: - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="bedrock", - logging_obj=logging, + return + model_response.system_fingerprint = ( + original_chunk.system_fingerprint + ) + print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}") + if self.sent_first_chunk == False: + model_response.choices[0].delta["role"] = "assistant" + self.sent_first_chunk = True + elif self.sent_first_chunk == True and hasattr( + model_response.choices[0].delta, "role" + ): + _initial_delta = model_response.choices[ + 0 + ].delta.model_dump() + _initial_delta.pop("role", None) + model_response.choices[0].delta = Delta(**_initial_delta) + print_verbose( + f"model_response.choices[0].delta: {model_response.choices[0].delta}" ) else: - response = CustomStreamWrapper( - iter(response), - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) - - if optional_params.get("stream", False): - ## LOGGING - logging.post_call( - input=messages, - api_key=None, - original_response=response, - ) - - ## RESPONSE OBJECT - response = response - elif custom_llm_provider == "watsonx": - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - response = watsonx.IBMWatsonXAI().completion( - model=model, - messages=messages, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, # type: ignore - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - timeout=timeout, # type: ignore - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and not isinstance(response, CustomStreamWrapper) - ): - # don't try to access stream object, - response = CustomStreamWrapper( - iter(response), - model, - custom_llm_provider="watsonx", - logging_obj=logging, - ) - - if optional_params.get("stream", False): - ## LOGGING - logging.post_call( - input=messages, - api_key=None, - original_response=response, - ) - ## RESPONSE OBJECT - response = response - elif custom_llm_provider == "vllm": - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - model_response = vllm.completion( - model=model, - messages=messages, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - ) - - if ( - "stream" in optional_params and optional_params["stream"] == True - ): ## [BETA] - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="vllm", - logging_obj=logging, - ) - return response - - ## RESPONSE OBJECT - response = model_response - elif custom_llm_provider == "ollama": - api_base = ( - litellm.api_base - or api_base - or get_secret("OLLAMA_API_BASE") - or "http://localhost:11434" - ) - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages, - ) + ## else + completion_obj["content"] = model_response_str + if self.sent_first_chunk == False: + completion_obj["role"] = "assistant" + self.sent_first_chunk = True + model_response.choices[0].delta = Delta(**completion_obj) + print_verbose(f"returning model_response: {model_response}") + return model_response else: - prompt = prompt_factory( - model=model, - messages=messages, - custom_llm_provider=custom_llm_provider, - ) - if isinstance(prompt, dict): - # for multimode models - ollama/llava prompt_factory returns a dict { - # "prompt": prompt, - # "images": images - # } - prompt, images = prompt["prompt"], prompt["images"] - optional_params["images"] = images - - ## LOGGING - generator = ollama.get_ollama_response( - api_base, - model, - prompt, - optional_params, - logging_obj=logging, - acompletion=acompletion, - model_response=model_response, - encoding=encoding, - ) - if acompletion is True or optional_params.get("stream", False) == True: - return generator - - response = generator - elif custom_llm_provider == "ollama_chat": - api_base = ( - litellm.api_base - or api_base - or get_secret("OLLAMA_API_BASE") - or "http://localhost:11434" + return + elif self.received_finish_reason is not None: + if self.sent_last_chunk == True: + raise StopIteration + # flush any remaining holding chunk + if len(self.holding_chunk) > 0: + if model_response.choices[0].delta.content is None: + model_response.choices[0].delta.content = self.holding_chunk + else: + model_response.choices[0].delta.content = ( + self.holding_chunk + model_response.choices[0].delta.content + ) + self.holding_chunk = "" + # if delta is None + _is_delta_empty = self.is_delta_empty( + delta=model_response.choices[0].delta ) - api_key = ( - api_key - or litellm.ollama_key - or os.environ.get("OLLAMA_API_KEY") - or litellm.api_key - ) - ## LOGGING - generator = ollama_chat.get_ollama_response( - api_base, - api_key, - model, - messages, - optional_params, - logging_obj=logging, - acompletion=acompletion, - model_response=model_response, - encoding=encoding, - ) - if acompletion is True or optional_params.get("stream", False) == True: - return generator + if _is_delta_empty: + # get any function call arguments + model_response.choices[0].finish_reason = map_finish_reason( + finish_reason=self.received_finish_reason + ) # ensure consistent output to openai + self.sent_last_chunk = True - response = generator - elif custom_llm_provider == "cloudflare": - api_key = ( - api_key - or litellm.cloudflare_api_key - or litellm.api_key - or get_secret("CLOUDFLARE_API_KEY") - ) - account_id = get_secret("CLOUDFLARE_ACCOUNT_ID") - api_base = ( - api_base - or litellm.api_base - or get_secret("CLOUDFLARE_API_BASE") - or f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/" - ) - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - response = cloudflare.completion( - model=model, - messages=messages, - api_base=api_base, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=api_key, - logging_obj=logging, - ) - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="cloudflare", - logging_obj=logging, - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - ) - response = response + return model_response elif ( - custom_llm_provider == "baseten" - or litellm.api_base == "https://app.baseten.co" + model_response.choices[0].delta.tool_calls is not None + or model_response.choices[0].delta.function_call is not None ): - custom_llm_provider = "baseten" - baseten_key = ( - api_key - or litellm.baseten_key - or os.environ.get("BASETEN_API_KEY") - or litellm.api_key - ) - - model_response = baseten.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=baseten_key, - logging_obj=logging, - ) - if inspect.isgenerator(model_response) or ( - "stream" in optional_params and optional_params["stream"] == True - ): - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="baseten", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "petals" or model in litellm.petals_models: - api_base = api_base or litellm.api_base - - custom_llm_provider = "petals" - stream = optional_params.pop("stream", False) - model_response = petals.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - ) - if stream == True: ## [BETA] - # Fake streaming for petals - resp_string = model_response["choices"][0]["message"]["content"] - response = CustomStreamWrapper( - resp_string, - model, - custom_llm_provider="petals", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "custom": - import requests - - url = litellm.api_base or api_base or "" - if url == None or url == "": - raise ValueError( - "api_base not set. Set api_base or litellm.api_base for custom endpoints" - ) - - """ - assume input to custom LLM api bases follow this format: - resp = requests.post( - api_base, - json={ - 'model': 'meta-llama/Llama-2-13b-hf', # model name - 'params': { - 'prompt': ["The capital of France is P"], - 'max_tokens': 32, - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 40, - } - } - ) - - """ - prompt = " ".join([message["content"] for message in messages]) # type: ignore - resp = requests.post( - url, - json={ - "model": model, - "params": { - "prompt": [prompt], - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "top_k": kwargs.get("top_k", 40), - }, - }, - ) - response_json = resp.json() - """ - assume all responses from custom api_bases of this format: - { - 'data': [ - { - 'prompt': 'The capital of France is P', - 'output': ['The capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France'], - 'params': {'temperature': 0.7, 'top_k': 40, 'top_p': 1}}], - 'message': 'ok' - } - ] - } - """ - string_response = response_json["data"][0]["output"][0] - ## RESPONSE OBJECT - model_response["choices"][0]["message"]["content"] = string_response - model_response["created"] = int(time.time()) - model_response["model"] = model - response = model_response + if self.sent_first_chunk == False: + model_response.choices[0].delta["role"] = "assistant" + self.sent_first_chunk = True + return model_response else: - raise ValueError( - f"Unable to map your input to a model. Check your input - {args}" - ) - return response + return + except StopIteration: + raise StopIteration except Exception as e: - ## Map to OpenAI Exception + traceback_exception = traceback.format_exc() + e.message = str(e) > raise exception_type( - model=model, - custom_llm_provider=custom_llm_provider, + model=self.model, + custom_llm_provider=self.custom_llm_provider, original_exception=e, - completion_kwargs=args, - extra_kwargs=kwargs, ) -../main.py:2287: +../utils.py:11380: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -model = 'llama-3-8b-instruct', original_exception = KeyError('stream') -custom_llm_provider = 'predibase' -completion_kwargs = {'acompletion': False, 'api_base': None, 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'api_version': None, ...} -extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , 'tenant_id': 'c4768f95'} +model = 'amazon.titan-tg1-large' +original_exception = AttributeError("'NoneType' object has no attribute 'get'") +custom_llm_provider = 'bedrock', completion_kwargs = {}, extra_kwargs = {} def exception_type( model, @@ -4071,33 +1087,39 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i # Common Extra information needed for all providers # We pass num retries, api_base, vertex_deployment etc to the exception here ################################################################################ + extra_information = "" + try: + _api_base = litellm.get_api_base( + model=model, optional_params=extra_kwargs + ) + messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) + _vertex_project = extra_kwargs.get("vertex_project") + _vertex_location = extra_kwargs.get("vertex_location") + _metadata = extra_kwargs.get("metadata", {}) or {} + _model_group = _metadata.get("model_group") + _deployment = _metadata.get("deployment") + extra_information = f"\nModel: {model}" + if _api_base: + extra_information += f"\nAPI Base: {_api_base}" + if messages and len(messages) > 0: + extra_information += f"\nMessages: {messages}" - _api_base = litellm.get_api_base(model=model, optional_params=extra_kwargs) - messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) - _vertex_project = extra_kwargs.get("vertex_project") - _vertex_location = extra_kwargs.get("vertex_location") - _metadata = extra_kwargs.get("metadata", {}) or {} - _model_group = _metadata.get("model_group") - _deployment = _metadata.get("deployment") - extra_information = f"\nModel: {model}" - if _api_base: - extra_information += f"\nAPI Base: {_api_base}" - if messages and len(messages) > 0: - extra_information += f"\nMessages: {messages}" + if _model_group is not None: + extra_information += f"\nmodel_group: {_model_group}\n" + if _deployment is not None: + extra_information += f"\ndeployment: {_deployment}\n" + if _vertex_project is not None: + extra_information += f"\nvertex_project: {_vertex_project}\n" + if _vertex_location is not None: + extra_information += f"\nvertex_location: {_vertex_location}\n" - if _model_group is not None: - extra_information += f"\nmodel_group: {_model_group}\n" - if _deployment is not None: - extra_information += f"\ndeployment: {_deployment}\n" - if _vertex_project is not None: - extra_information += f"\nvertex_project: {_vertex_project}\n" - if _vertex_location is not None: - extra_information += f"\nvertex_location: {_vertex_location}\n" - - # on litellm proxy add key name + team to exceptions - extra_information = _add_key_name_and_team_to_alert( - request_info=extra_information, metadata=_metadata - ) + # on litellm proxy add key name + team to exceptions + extra_information = _add_key_name_and_team_to_alert( + request_info=extra_information, metadata=_metadata + ) + except: + # DO NOT LET this Block raising the original exception + pass ################################################################################ # End of Common Extra information Needed for all providers @@ -4110,9 +1132,10 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i if "Request Timeout Error" in error_str or "Request timed out" in error_str: exception_mapping_worked = True raise Timeout( - message=f"APITimeoutError - Request timed out. {extra_information} \n error_str: {error_str}", + message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}", model=model, llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, ) if ( @@ -4139,16 +1162,14 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i + "Exception" ) - if ( - "This model's maximum context length is" in error_str - or "Request too large" in error_str - ): + if "This model's maximum context length is" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, ) elif ( "invalid_request_error" in error_str @@ -4156,10 +1177,11 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise NotFoundError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, ) elif ( "invalid_request_error" in error_str @@ -4167,10 +1189,11 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise ContentPolicyViolationError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, ) elif ( "invalid_request_error" in error_str @@ -4178,10 +1201,19 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise BadRequestError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, + ) + elif "Request too large" in error_str: + raise RateLimitError( + message=f"{exception_provider} - {message}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + litellm_debug_info=extra_information, ) elif ( "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" @@ -4189,10 +1221,11 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise AuthenticationError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, ) elif "Mistral API raised a streaming error" in error_str: exception_mapping_worked = True @@ -4201,82 +1234,92 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ) raise APIError( status_code=500, - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, request=_request, + litellm_debug_info=extra_information, ) elif hasattr(original_exception, "status_code"): exception_mapping_worked = True if original_exception.status_code == 401: exception_mapping_worked = True raise AuthenticationError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 404: exception_mapping_worked = True raise NotFoundError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, response=original_exception.response, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 408: exception_mapping_worked = True raise Timeout( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 422: exception_mapping_worked = True raise BadRequestError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, response=original_exception.response, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 429: exception_mapping_worked = True raise RateLimitError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, response=original_exception.response, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 503: exception_mapping_worked = True raise ServiceUnavailableError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, response=original_exception.response, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True raise Timeout( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, ) else: exception_mapping_worked = True raise APIError( status_code=original_exception.status_code, - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, request=original_exception.request, + litellm_debug_info=extra_information, ) else: # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors raise APIConnectionError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, + litellm_debug_info=extra_information, request=httpx.Request( method="POST", url="https://api.openai.com/v1/" ), @@ -4430,8 +1473,42 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i message=f"ReplicateException - {str(original_exception)}", llm_provider="replicate", model=model, - request=original_exception.request, + request=httpx.Request( + method="POST", + url="https://api.replicate.com/v1/deployments", + ), ) + elif custom_llm_provider == "watsonx": + if "token_quota_reached" in error_str: + exception_mapping_worked = True + raise RateLimitError( + message=f"WatsonxException: Rate Limit Errror - {error_str}", + llm_provider="watsonx", + model=model, + response=original_exception.response, + ) + elif custom_llm_provider == "predibase": + if "authorization denied for" in error_str: + exception_mapping_worked = True + + # Predibase returns the raw API Key in the response - this block ensures it's not returned in the exception + if ( + error_str is not None + and isinstance(error_str, str) + and "bearer" in error_str.lower() + ): + # only keep the first 10 chars after the occurnence of "bearer" + _bearer_token_start_index = error_str.lower().find("bearer") + error_str = error_str[: _bearer_token_start_index + 14] + error_str += "XXXXXXX" + '"' + + raise AuthenticationError( + message=f"PredibaseException: Authentication Error - {error_str}", + llm_provider="predibase", + model=model, + response=original_exception.response, + litellm_debug_info=extra_information, + ) elif custom_llm_provider == "bedrock": if ( "too many tokens" in error_str @@ -4569,10 +1646,11 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", response=original_exception.response, + litellm_debug_info=extra_information, ) elif ( "None Unknown Error." in error_str @@ -4580,26 +1658,29 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise APIError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", status_code=500, model=model, llm_provider="vertex_ai", request=original_exception.request, + litellm_debug_info=extra_information, ) elif "403" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", response=original_exception.response, + litellm_debug_info=extra_information, ) elif "The response was blocked." in error_str: exception_mapping_worked = True raise UnprocessableEntityError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", + litellm_debug_info=extra_information, response=httpx.Response( status_code=429, request=httpx.Request( @@ -4616,9 +1697,10 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise RateLimitError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", + litellm_debug_info=extra_information, response=httpx.Response( status_code=429, request=httpx.Request( @@ -4631,18 +1713,20 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i if original_exception.status_code == 400: exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", + litellm_debug_info=extra_information, response=original_exception.response, ) if original_exception.status_code == 500: exception_mapping_worked = True raise APIError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", status_code=500, model=model, llm_provider="vertex_ai", + litellm_debug_info=extra_information, request=original_exception.request, ) elif custom_llm_provider == "palm" or custom_llm_provider == "gemini": @@ -5243,25 +2327,28 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i exception_mapping_worked = True raise APIError( status_code=500, - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, request=httpx.Request(method="POST", url="https://openai.com/"), ) elif "This model's maximum context length is" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif "DeploymentNotFound" in error_str: exception_mapping_worked = True raise NotFoundError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif ( @@ -5273,17 +2360,19 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise ContentPolicyViolationError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif "invalid_request_error" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif ( @@ -5292,9 +2381,10 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise AuthenticationError( - message=f"{exception_provider} - {original_exception.message} {extra_information}", + message=f"{exception_provider} - {original_exception.message}", llm_provider=custom_llm_provider, model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif hasattr(original_exception, "status_code"): @@ -5302,55 +2392,62 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i if original_exception.status_code == 401: exception_mapping_worked = True raise AuthenticationError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True raise Timeout( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", model=model, + litellm_debug_info=extra_information, llm_provider="azure", ) if original_exception.status_code == 422: exception_mapping_worked = True raise BadRequestError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", + litellm_debug_info=extra_information, response=original_exception.response, ) elif original_exception.status_code == 429: exception_mapping_worked = True raise RateLimitError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", + litellm_debug_info=extra_information, response=original_exception.response, ) elif original_exception.status_code == 503: exception_mapping_worked = True raise ServiceUnavailableError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", + litellm_debug_info=extra_information, response=original_exception.response, ) elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True raise Timeout( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", model=model, + litellm_debug_info=extra_information, llm_provider="azure", ) else: exception_mapping_worked = True raise APIError( status_code=original_exception.status_code, - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", + litellm_debug_info=extra_information, model=model, request=httpx.Request( method="POST", url="https://openai.com/" @@ -5359,9 +2456,10 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i else: # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors raise APIConnectionError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, request=httpx.Request(method="POST", url="https://openai.com/"), ) if ( @@ -5412,13 +2510,12 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i if exception_mapping_worked: > raise e -../utils.py:9353: +../utils.py:9661: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -model = 'llama-3-8b-instruct', original_exception = KeyError('stream') -custom_llm_provider = 'predibase' -completion_kwargs = {'acompletion': False, 'api_base': None, 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'api_version': None, ...} -extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , 'tenant_id': 'c4768f95'} +model = 'amazon.titan-tg1-large' +original_exception = AttributeError("'NoneType' object has no attribute 'get'") +custom_llm_provider = 'bedrock', completion_kwargs = {}, extra_kwargs = {} def exception_type( model, @@ -5450,33 +2547,39 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i # Common Extra information needed for all providers # We pass num retries, api_base, vertex_deployment etc to the exception here ################################################################################ + extra_information = "" + try: + _api_base = litellm.get_api_base( + model=model, optional_params=extra_kwargs + ) + messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) + _vertex_project = extra_kwargs.get("vertex_project") + _vertex_location = extra_kwargs.get("vertex_location") + _metadata = extra_kwargs.get("metadata", {}) or {} + _model_group = _metadata.get("model_group") + _deployment = _metadata.get("deployment") + extra_information = f"\nModel: {model}" + if _api_base: + extra_information += f"\nAPI Base: {_api_base}" + if messages and len(messages) > 0: + extra_information += f"\nMessages: {messages}" - _api_base = litellm.get_api_base(model=model, optional_params=extra_kwargs) - messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) - _vertex_project = extra_kwargs.get("vertex_project") - _vertex_location = extra_kwargs.get("vertex_location") - _metadata = extra_kwargs.get("metadata", {}) or {} - _model_group = _metadata.get("model_group") - _deployment = _metadata.get("deployment") - extra_information = f"\nModel: {model}" - if _api_base: - extra_information += f"\nAPI Base: {_api_base}" - if messages and len(messages) > 0: - extra_information += f"\nMessages: {messages}" + if _model_group is not None: + extra_information += f"\nmodel_group: {_model_group}\n" + if _deployment is not None: + extra_information += f"\ndeployment: {_deployment}\n" + if _vertex_project is not None: + extra_information += f"\nvertex_project: {_vertex_project}\n" + if _vertex_location is not None: + extra_information += f"\nvertex_location: {_vertex_location}\n" - if _model_group is not None: - extra_information += f"\nmodel_group: {_model_group}\n" - if _deployment is not None: - extra_information += f"\ndeployment: {_deployment}\n" - if _vertex_project is not None: - extra_information += f"\nvertex_project: {_vertex_project}\n" - if _vertex_location is not None: - extra_information += f"\nvertex_location: {_vertex_location}\n" - - # on litellm proxy add key name + team to exceptions - extra_information = _add_key_name_and_team_to_alert( - request_info=extra_information, metadata=_metadata - ) + # on litellm proxy add key name + team to exceptions + extra_information = _add_key_name_and_team_to_alert( + request_info=extra_information, metadata=_metadata + ) + except: + # DO NOT LET this Block raising the original exception + pass ################################################################################ # End of Common Extra information Needed for all providers @@ -5489,9 +2592,10 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i if "Request Timeout Error" in error_str or "Request timed out" in error_str: exception_mapping_worked = True raise Timeout( - message=f"APITimeoutError - Request timed out. {extra_information} \n error_str: {error_str}", + message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}", model=model, llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, ) if ( @@ -5518,16 +2622,14 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i + "Exception" ) - if ( - "This model's maximum context length is" in error_str - or "Request too large" in error_str - ): + if "This model's maximum context length is" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, ) elif ( "invalid_request_error" in error_str @@ -5535,10 +2637,11 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise NotFoundError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, ) elif ( "invalid_request_error" in error_str @@ -5546,10 +2649,11 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise ContentPolicyViolationError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, ) elif ( "invalid_request_error" in error_str @@ -5557,10 +2661,19 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise BadRequestError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, + ) + elif "Request too large" in error_str: + raise RateLimitError( + message=f"{exception_provider} - {message}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + litellm_debug_info=extra_information, ) elif ( "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" @@ -5568,10 +2681,11 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise AuthenticationError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, ) elif "Mistral API raised a streaming error" in error_str: exception_mapping_worked = True @@ -5580,82 +2694,92 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ) raise APIError( status_code=500, - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, request=_request, + litellm_debug_info=extra_information, ) elif hasattr(original_exception, "status_code"): exception_mapping_worked = True if original_exception.status_code == 401: exception_mapping_worked = True raise AuthenticationError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, response=original_exception.response, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 404: exception_mapping_worked = True raise NotFoundError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, response=original_exception.response, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 408: exception_mapping_worked = True raise Timeout( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 422: exception_mapping_worked = True raise BadRequestError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, response=original_exception.response, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 429: exception_mapping_worked = True raise RateLimitError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, response=original_exception.response, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 503: exception_mapping_worked = True raise ServiceUnavailableError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, response=original_exception.response, + litellm_debug_info=extra_information, ) elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True raise Timeout( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, ) else: exception_mapping_worked = True raise APIError( status_code=original_exception.status_code, - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, request=original_exception.request, + litellm_debug_info=extra_information, ) else: # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors raise APIConnectionError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, + litellm_debug_info=extra_information, request=httpx.Request( method="POST", url="https://api.openai.com/v1/" ), @@ -5809,8 +2933,42 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i message=f"ReplicateException - {str(original_exception)}", llm_provider="replicate", model=model, - request=original_exception.request, + request=httpx.Request( + method="POST", + url="https://api.replicate.com/v1/deployments", + ), ) + elif custom_llm_provider == "watsonx": + if "token_quota_reached" in error_str: + exception_mapping_worked = True + raise RateLimitError( + message=f"WatsonxException: Rate Limit Errror - {error_str}", + llm_provider="watsonx", + model=model, + response=original_exception.response, + ) + elif custom_llm_provider == "predibase": + if "authorization denied for" in error_str: + exception_mapping_worked = True + + # Predibase returns the raw API Key in the response - this block ensures it's not returned in the exception + if ( + error_str is not None + and isinstance(error_str, str) + and "bearer" in error_str.lower() + ): + # only keep the first 10 chars after the occurnence of "bearer" + _bearer_token_start_index = error_str.lower().find("bearer") + error_str = error_str[: _bearer_token_start_index + 14] + error_str += "XXXXXXX" + '"' + + raise AuthenticationError( + message=f"PredibaseException: Authentication Error - {error_str}", + llm_provider="predibase", + model=model, + response=original_exception.response, + litellm_debug_info=extra_information, + ) elif custom_llm_provider == "bedrock": if ( "too many tokens" in error_str @@ -5948,10 +3106,11 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", response=original_exception.response, + litellm_debug_info=extra_information, ) elif ( "None Unknown Error." in error_str @@ -5959,26 +3118,29 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise APIError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", status_code=500, model=model, llm_provider="vertex_ai", request=original_exception.request, + litellm_debug_info=extra_information, ) elif "403" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", response=original_exception.response, + litellm_debug_info=extra_information, ) elif "The response was blocked." in error_str: exception_mapping_worked = True raise UnprocessableEntityError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", + litellm_debug_info=extra_information, response=httpx.Response( status_code=429, request=httpx.Request( @@ -5995,9 +3157,10 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise RateLimitError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", + litellm_debug_info=extra_information, response=httpx.Response( status_code=429, request=httpx.Request( @@ -6010,18 +3173,20 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i if original_exception.status_code == 400: exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", + litellm_debug_info=extra_information, response=original_exception.response, ) if original_exception.status_code == 500: exception_mapping_worked = True raise APIError( - message=f"VertexAIException - {error_str} {extra_information}", + message=f"VertexAIException - {error_str}", status_code=500, model=model, llm_provider="vertex_ai", + litellm_debug_info=extra_information, request=original_exception.request, ) elif custom_llm_provider == "palm" or custom_llm_provider == "gemini": @@ -6622,25 +3787,28 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i exception_mapping_worked = True raise APIError( status_code=500, - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, request=httpx.Request(method="POST", url="https://openai.com/"), ) elif "This model's maximum context length is" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif "DeploymentNotFound" in error_str: exception_mapping_worked = True raise NotFoundError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif ( @@ -6652,17 +3820,19 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise ContentPolicyViolationError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif "invalid_request_error" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif ( @@ -6671,9 +3841,10 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i ): exception_mapping_worked = True raise AuthenticationError( - message=f"{exception_provider} - {original_exception.message} {extra_information}", + message=f"{exception_provider} - {original_exception.message}", llm_provider=custom_llm_provider, model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif hasattr(original_exception, "status_code"): @@ -6681,55 +3852,62 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i if original_exception.status_code == 401: exception_mapping_worked = True raise AuthenticationError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True raise Timeout( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", model=model, + litellm_debug_info=extra_information, llm_provider="azure", ) if original_exception.status_code == 422: exception_mapping_worked = True raise BadRequestError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", + litellm_debug_info=extra_information, response=original_exception.response, ) elif original_exception.status_code == 429: exception_mapping_worked = True raise RateLimitError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", + litellm_debug_info=extra_information, response=original_exception.response, ) elif original_exception.status_code == 503: exception_mapping_worked = True raise ServiceUnavailableError( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", + litellm_debug_info=extra_information, response=original_exception.response, ) elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True raise Timeout( - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", model=model, + litellm_debug_info=extra_information, llm_provider="azure", ) else: exception_mapping_worked = True raise APIError( status_code=original_exception.status_code, - message=f"AzureException - {original_exception.message} {extra_information}", + message=f"AzureException - {original_exception.message}", llm_provider="azure", + litellm_debug_info=extra_information, model=model, request=httpx.Request( method="POST", url="https://openai.com/" @@ -6738,9 +3916,10 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i else: # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors raise APIConnectionError( - message=f"{exception_provider} - {message} {extra_information}", + message=f"{exception_provider} - {message}", llm_provider="azure", model=model, + litellm_debug_info=extra_information, request=httpx.Request(method="POST", url="https://openai.com/"), ) if ( @@ -6772,224 +3951,184 @@ extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_i method="POST", url="https://api.openai.com/v1/" ), # stub the request ) -E litellm.exceptions.APIConnectionError: 'stream' +E litellm.exceptions.APIConnectionError: 'NoneType' object has no attribute 'get' -../utils.py:9328: APIConnectionError +../utils.py:9636: APIConnectionError During handling of the above exception, another exception occurred: -sync_mode = True +sync_mode = False, model = 'bedrock/amazon.titan-tg1-large' @pytest.mark.parametrize("sync_mode", [True, False]) + @pytest.mark.parametrize( + "model", + [ + # "bedrock/cohere.command-r-plus-v1:0", + # "anthropic.claude-3-sonnet-20240229-v1:0", + # "anthropic.claude-instant-v1", + # "bedrock/ai21.j2-mid", + # "mistral.mistral-7b-instruct-v0:2", + "bedrock/amazon.titan-tg1-large", + # "meta.llama3-8b-instruct-v1:0", + ], + ) @pytest.mark.asyncio - async def test_completion_predibase_streaming(sync_mode): + async def test_bedrock_httpx_streaming(sync_mode, model): try: litellm.set_verbose = True - if sync_mode: - response = completion( - model="predibase/llama-3-8b-instruct", - tenant_id="c4768f95", - api_base="https://serving.app.predibase.com", - api_key=os.getenv("PREDIBASE_API_KEY"), - messages=[{"role": "user", "content": "What is the meaning of life?"}], + final_chunk: Optional[litellm.ModelResponse] = None + response: litellm.CustomStreamWrapper = completion( # type: ignore + model=model, + messages=messages, + max_tokens=10, # type: ignore stream=True, ) - complete_response = "" - for idx, init_chunk in enumerate(response): - chunk, finished = streaming_format_tests(idx, init_chunk) - complete_response += chunk - custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"] - print(f"custom_llm_provider: {custom_llm_provider}") - assert custom_llm_provider == "predibase" + # Add any assertions here to check the response + has_finish_reason = False + for idx, chunk in enumerate(response): + final_chunk = chunk + chunk, finished = streaming_format_tests(idx, chunk) if finished: - assert isinstance( - init_chunk.choices[0], litellm.utils.StreamingChoices - ) + has_finish_reason = True break + complete_response += chunk + if has_finish_reason == False: + raise Exception("finish reason not set") if complete_response.strip() == "": raise Exception("Empty response received") else: - response = await litellm.acompletion( - model="predibase/llama-3-8b-instruct", - tenant_id="c4768f95", - api_base="https://serving.app.predibase.com", - api_key=os.getenv("PREDIBASE_API_KEY"), - messages=[{"role": "user", "content": "What is the meaning of life?"}], + response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore + model=model, + messages=messages, + max_tokens=100, # type: ignore stream=True, ) - - # await response - complete_response = "" + # Add any assertions here to check the response + has_finish_reason = False idx = 0 - async for init_chunk in response: - chunk, finished = streaming_format_tests(idx, init_chunk) - complete_response += chunk - custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"] - print(f"custom_llm_provider: {custom_llm_provider}") - assert custom_llm_provider == "predibase" - idx += 1 + final_chunk: Optional[litellm.ModelResponse] = None + async for chunk in response: + final_chunk = chunk + chunk, finished = streaming_format_tests(idx, chunk) if finished: - assert isinstance( - init_chunk.choices[0], litellm.utils.StreamingChoices - ) + has_finish_reason = True break + complete_response += chunk + idx += 1 + if has_finish_reason == False: + raise Exception("finish reason not set") if complete_response.strip() == "": raise Exception("Empty response received") - - print(f"complete_response: {complete_response}") - except litellm.Timeout as e: + print(f"completion_response: {complete_response}\n\nFinalChunk: {final_chunk}") + except RateLimitError: pass except Exception as e: > pytest.fail(f"Error occurred: {e}") -E Failed: Error occurred: 'stream' +E Failed: Error occurred: 'NoneType' object has no attribute 'get' -test_streaming.py:373: Failed +test_streaming.py:1110: Failed ---------------------------- Captured stdout setup ----------------------------- ----------------------------- Captured stdout call ----------------------------- Request to litellm: -litellm.completion(model='predibase/llama-3-8b-instruct', tenant_id='c4768f95', api_base='https://serving.app.predibase.com', api_key='pb_Qg9YbQo7UqqHdu0ozxN_aw', messages=[{'role': 'user', 'content': 'What is the meaning of life?'}], stream=True) +litellm.acompletion(model='bedrock/amazon.titan-tg1-large', messages=[{'content': 'Hello, how are you?', 'role': 'user'}], max_tokens=100, stream=True) self.optional_params: {} -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model=llama-3-8b-instruct, custom_llm_provider=predibase -Final returned optional params: {'stream': True, 'tenant_id': 'c4768f95'} -self.optional_params: {'stream': True, 'tenant_id': 'c4768f95'} +ASYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache'): None +Final returned optional params: {'maxTokenCount': 100, 'stream': True} +self.optional_params: {'maxTokenCount': 100, 'stream': True}  POST Request Sent from LiteLLM: curl -X POST \ -https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate_stream \ --H 'content-type: application/json' -H 'Authorization: Bearer pb_Qg********************' \ --d '{'inputs': 'What is the meaning of life?', 'parameters': {'details': True, 'max_new_tokens': 256, 'return_full_text': False}}' +https://bedrock-runtime.us-west-2.amazonaws.com/model/amazon.titan-tg1-large/invoke-with-response-stream \ +-H 'Content-Type: application/json' -H 'X-Amz-Date: 20240517T053236Z' -H 'Authorization: AWS4-HMAC-SHA256 Credential=AKIA45ZGR4NCKSABWA6O/20240517/us-west-2/bedrock/aws4_request, SignedHeaders=content-type;host;x-amz-date, Signature=128337479260a5d917f2dd0656a6d57d1662a6c8819f********************' -H 'Content-Length: 84' \ +-d '{"inputText": "Hello, how are you?", "textGenerationConfig": {"maxTokenCount": 100}}'  +value of async chunk: {'text': '\nHello, I am an AI model developed by Amazon Titan Foundation Models. I have been trained on vast amounts of data, making me capable of understanding and generating human-like text. My development has been focused on continuously improving my pe', 'is_finished': False, 'finish_reason': ''} +PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {'text': '\nHello, I am an AI model developed by Amazon Titan Foundation Models. I have been trained on vast amounts of data, making me capable of understanding and generating human-like text. My development has been focused on continuously improving my pe', 'is_finished': False, 'finish_reason': ''} Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. + +Provider List: https://docs.litellm.ai/docs/providers + Logging Details: logger_fn - None | callable(logger_fn) - False -Logging Details LiteLLM-Failure Call -self.failure_callback: [] =============================== warnings summary =============================== ../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings /opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning) -../proxy/_types.py:219 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:219: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:255 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:255: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:306 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:306: PydanticDeprecatedSince20: `pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:342 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:342: PydanticDeprecatedSince20: `pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ extra = Extra.allow # Allow extra fields -../proxy/_types.py:309 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:309: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:345 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:345: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:338 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:338: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:374 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:374: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:385 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:385: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:421 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:421: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:454 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:454: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:490 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:490: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:474 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:474: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:510 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:510: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:487 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:487: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:523 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:523: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:532 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:532: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:568 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:568: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:569 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:569: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:605 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:605: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:864 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:864: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:923 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:923: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:891 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:891: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:950 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:950: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:912 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:912: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:971 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:971: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../utils.py:39 - /Users/krrishdholakia/Documents/litellm/litellm/utils.py:39: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html - import pkg_resources # type: ignore +../utils.py:60 + /Users/krrishdholakia/Documents/litellm/litellm/utils.py:60: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice. + with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f: -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: 10 warnings - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`. - Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages - declare_namespace(pkg) - -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.cloud')`. - Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages - declare_namespace(pkg) - -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2317 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2317 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2317 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2317: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`. - Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages - declare_namespace(parent) - -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.logging')`. - Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages - declare_namespace(pkg) - -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.iam')`. - Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages - declare_namespace(pkg) - -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('mpl_toolkits')`. - Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages - declare_namespace(pkg) - -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('sphinxcontrib')`. - Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages - declare_namespace(pkg) - -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('zope')`. - Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages - declare_namespace(pkg) - -test_streaming.py::test_completion_predibase_streaming[False] +test_streaming.py::test_bedrock_httpx_streaming[bedrock/amazon.titan-tg1-large-False] /opt/homebrew/lib/python3.11/site-packages/httpx/_content.py:204: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content. warnings.warn(message, DeprecationWarning) -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html =========================== short test summary info ============================ -FAILED test_streaming.py::test_completion_predibase_streaming[True] - Failed:... -=================== 1 failed, 1 passed, 64 warnings in 5.28s =================== +FAILED test_streaming.py::test_bedrock_httpx_streaming[bedrock/amazon.titan-tg1-large-False] +!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!! +======================== 1 failed, 40 warnings in 3.56s ======================== diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index f3ec308fb..68143f9ac 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -2670,6 +2670,9 @@ def response_format_tests(response: litellm.ModelResponse): "anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-instant-v1", "bedrock/ai21.j2-mid", + "mistral.mistral-7b-instruct-v0:2", + "bedrock/amazon.titan-tg1-large", + "meta.llama3-8b-instruct-v1:0", ], ) @pytest.mark.asyncio @@ -2692,7 +2695,7 @@ async def test_completion_bedrock_httpx_models(sync_mode, model): model=model, messages=[{"role": "user", "content": "Hey! how's it going?"}], temperature=0.2, - max_tokens=200, + max_tokens=100, ) assert isinstance(response, litellm.ModelResponse) @@ -2728,24 +2731,6 @@ def test_completion_bedrock_titan_null_response(): pytest.fail(f"An error occurred - {str(e)}") -def test_completion_bedrock_titan(): - try: - response = completion( - model="bedrock/amazon.titan-tg1-large", - messages=messages, - temperature=0.2, - max_tokens=200, - top_p=0.8, - logger_fn=logger_fn, - ) - # Add any assertions here to check the response - print(response) - except RateLimitError: - pass - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - # test_completion_bedrock_titan() diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index e4aa8b135..59f435a7e 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1048,6 +1048,9 @@ async def test_completion_replicate_llama3_streaming(sync_mode): "anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-instant-v1", "bedrock/ai21.j2-mid", + "mistral.mistral-7b-instruct-v0:2", + "bedrock/amazon.titan-tg1-large", + "meta.llama3-8b-instruct-v1:0", ], ) @pytest.mark.asyncio diff --git a/litellm/utils.py b/litellm/utils.py index 51f31a1ff..82c31fe4b 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -10637,75 +10637,11 @@ class CustomStreamWrapper: raise e def handle_bedrock_stream(self, chunk): - if "cohere" in self.model or "anthropic" in self.model: - return { - "text": chunk["text"], - "is_finished": chunk["is_finished"], - "finish_reason": chunk["finish_reason"], - } - if hasattr(chunk, "get"): - chunk = chunk.get("chunk") - chunk_data = json.loads(chunk.get("bytes").decode()) - else: - chunk_data = json.loads(chunk.decode()) - if chunk_data: - text = "" - is_finished = False - finish_reason = "" - if "outputText" in chunk_data: - text = chunk_data["outputText"] - # ai21 mapping - if "ai21" in self.model: # fake ai21 streaming - text = chunk_data.get("completions")[0].get("data").get("text") - is_finished = True - finish_reason = "stop" - ######## bedrock.anthropic mappings ############### - elif "completion" in chunk_data: # not claude-3 - text = chunk_data["completion"] # bedrock.anthropic - stop_reason = chunk_data.get("stop_reason", None) - if stop_reason != None: - is_finished = True - finish_reason = stop_reason - elif "delta" in chunk_data: - if chunk_data["delta"].get("text", None) is not None: - text = chunk_data["delta"]["text"] - stop_reason = chunk_data["delta"].get("stop_reason", None) - if stop_reason != None: - is_finished = True - finish_reason = stop_reason - ######## bedrock.mistral mappings ############### - elif "outputs" in chunk_data: - if ( - len(chunk_data["outputs"]) == 1 - and chunk_data["outputs"][0].get("text", None) is not None - ): - text = chunk_data["outputs"][0]["text"] - stop_reason = chunk_data.get("stop_reason", None) - if stop_reason != None: - is_finished = True - finish_reason = stop_reason - ######## bedrock.cohere mappings ############### - # meta mapping - elif "generation" in chunk_data: - text = chunk_data["generation"] # bedrock.meta - # cohere mapping - elif "text" in chunk_data: - text = chunk_data["text"] # bedrock.cohere - # cohere mapping for finish reason - elif "finish_reason" in chunk_data: - finish_reason = chunk_data["finish_reason"] - is_finished = True - elif chunk_data.get("completionReason", None): - is_finished = True - finish_reason = chunk_data["completionReason"] - elif chunk.get("error", None): - raise Exception(chunk["error"]) - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - return "" + return { + "text": chunk["text"], + "is_finished": chunk["is_finished"], + "finish_reason": chunk["finish_reason"], + } def handle_sagemaker_stream(self, chunk): if "data: [DONE]" in chunk: @@ -11508,14 +11444,7 @@ class CustomStreamWrapper: or self.custom_llm_provider == "replicate" or self.custom_llm_provider == "cached_response" or self.custom_llm_provider == "predibase" - or ( - self.custom_llm_provider == "bedrock" - and ( - "cohere" in self.model - or "anthropic" in self.model - or "ai21" in self.model - ) - ) + or self.custom_llm_provider == "bedrock" or self.custom_llm_provider in litellm.openai_compatible_endpoints ): async for chunk in self.completion_stream: