From edb10198efd70b9161fbf2981aad957d185af43e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 8 May 2024 21:25:40 -0700 Subject: [PATCH 01/34] feat - add stream_options support litellm --- litellm/main.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/litellm/main.py b/litellm/main.py index bff9886ac..d6d276653 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -187,6 +187,7 @@ async def acompletion( top_p: Optional[float] = None, n: Optional[int] = None, stream: Optional[bool] = None, + stream_options: Optional[dict] = None, stop=None, max_tokens: Optional[int] = None, presence_penalty: Optional[float] = None, @@ -223,6 +224,7 @@ async def acompletion( top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). n (int, optional): The number of completions to generate (default is 1). stream (bool, optional): If True, return a streaming response (default is False). + stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True. stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. @@ -260,6 +262,7 @@ async def acompletion( "top_p": top_p, "n": n, "stream": stream, + "stream_options": stream_options, "stop": stop, "max_tokens": max_tokens, "presence_penalty": presence_penalty, @@ -457,6 +460,7 @@ def completion( top_p: Optional[float] = None, n: Optional[int] = None, stream: Optional[bool] = None, + stream_options: Optional[dict] = None, stop=None, max_tokens: Optional[int] = None, presence_penalty: Optional[float] = None, @@ -496,6 +500,7 @@ def completion( top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). n (int, optional): The number of completions to generate (default is 1). stream (bool, optional): If True, return a streaming response (default is False). + stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true. stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. @@ -573,6 +578,7 @@ def completion( "top_p", "n", "stream", + "stream_options", "stop", "max_tokens", "presence_penalty", @@ -783,6 +789,7 @@ def completion( top_p=top_p, n=n, stream=stream, + stream_options=stream_options, stop=stop, max_tokens=max_tokens, presence_penalty=presence_penalty, From 10420516020f5da2e5369035da7898f218d2449e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 8 May 2024 21:52:25 -0700 Subject: [PATCH 02/34] support stream_options for chat completion models --- litellm/llms/openai.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index d516334ac..d542cbe07 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -530,6 +530,7 @@ class OpenAIChatCompletion(BaseLLM): model=model, custom_llm_provider="openai", logging_obj=logging_obj, + stream_options=data.get("stream_options", None), ) return streamwrapper @@ -579,6 +580,7 @@ class OpenAIChatCompletion(BaseLLM): model=model, custom_llm_provider="openai", logging_obj=logging_obj, + stream_options=data.get("stream_options", None), ) return streamwrapper except ( From f2965660dd2222ee74f7cd447f9b4ace70ba8364 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 8 May 2024 21:52:39 -0700 Subject: [PATCH 03/34] test openai stream_options --- litellm/tests/test_streaming.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 271a53dd4..7d639d7a3 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1501,6 +1501,37 @@ def test_openai_chat_completion_complete_response_call(): # test_openai_chat_completion_complete_response_call() +def test_openai_stream_options_call(): + litellm.set_verbose = False + response = litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "system", "content": "say GM - we're going to make it "}], + stream=True, + stream_options={"include_usage": True}, + max_tokens=10, + ) + usage = None + chunks = [] + for chunk in response: + print("chunk: ", chunk) + chunks.append(chunk) + + last_chunk = chunks[-1] + print("last chunk: ", last_chunk) + + """ + Assert that: + - Last Chunk includes Usage + - All chunks prior to last chunk have usage=None + """ + + assert last_chunk.usage is not None + assert last_chunk.usage.total_tokens > 0 + assert last_chunk.usage.prompt_tokens > 0 + assert last_chunk.usage.completion_tokens > 0 + + # assert all non last chunks have usage=None + assert all(chunk.usage is None for chunk in chunks[:-1]) def test_openai_text_completion_call(): From 80ca011a642ae6206c293eef15c951c1f74a8e3c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 8 May 2024 21:53:33 -0700 Subject: [PATCH 04/34] support stream_options --- litellm/utils.py | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/litellm/utils.py b/litellm/utils.py index df58db29c..64a644f15 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -612,6 +612,7 @@ class ModelResponse(OpenAIObject): system_fingerprint=None, usage=None, stream=None, + stream_options=None, response_ms=None, hidden_params=None, **params, @@ -658,6 +659,12 @@ class ModelResponse(OpenAIObject): usage = usage elif stream is None or stream == False: usage = Usage() + elif ( + stream == True + and stream_options is not None + and stream_options.get("include_usage") == True + ): + usage = Usage() if hidden_params: self._hidden_params = hidden_params @@ -4839,6 +4846,7 @@ def get_optional_params( top_p=None, n=None, stream=False, + stream_options=None, stop=None, max_tokens=None, presence_penalty=None, @@ -4908,6 +4916,7 @@ def get_optional_params( "top_p": None, "n": None, "stream": None, + "stream_options": None, "stop": None, "max_tokens": None, "presence_penalty": None, @@ -5779,6 +5788,8 @@ def get_optional_params( optional_params["n"] = n if stream is not None: optional_params["stream"] = stream + if stream_options is not None: + optional_params["stream_options"] = stream_options if stop is not None: optional_params["stop"] = stop if max_tokens is not None: @@ -6049,6 +6060,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): "top_p", "n", "stream", + "stream_options", "stop", "max_tokens", "presence_penalty", @@ -9466,7 +9478,12 @@ def get_secret( # replicate/anthropic/cohere class CustomStreamWrapper: def __init__( - self, completion_stream, model, custom_llm_provider=None, logging_obj=None + self, + completion_stream, + model, + custom_llm_provider=None, + logging_obj=None, + stream_options=None, ): self.model = model self.custom_llm_provider = custom_llm_provider @@ -9492,6 +9509,7 @@ class CustomStreamWrapper: self.response_id = None self.logging_loop = None self.rules = Rules() + self.stream_options = stream_options def __iter__(self): return self @@ -9932,6 +9950,7 @@ class CustomStreamWrapper: is_finished = False finish_reason = None logprobs = None + usage = None original_chunk = None # this is used for function/tool calling if len(str_line.choices) > 0: if ( @@ -9966,12 +9985,15 @@ class CustomStreamWrapper: else: logprobs = None + usage = getattr(str_line, "usage", None) + return { "text": text, "is_finished": is_finished, "finish_reason": finish_reason, "logprobs": logprobs, "original_chunk": str_line, + "usage": usage, } except Exception as e: traceback.print_exc() @@ -10274,7 +10296,9 @@ class CustomStreamWrapper: raise e def model_response_creator(self): - model_response = ModelResponse(stream=True, model=self.model) + model_response = ModelResponse( + stream=True, model=self.model, stream_options=self.stream_options + ) if self.response_id is not None: model_response.id = self.response_id else: @@ -10594,6 +10618,12 @@ class CustomStreamWrapper: if response_obj["logprobs"] is not None: model_response.choices[0].logprobs = response_obj["logprobs"] + if ( + self.stream_options is not None + and self.stream_options["include_usage"] == True + ): + model_response.usage = response_obj["usage"] + model_response.model = self.model print_verbose( f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}" @@ -10681,6 +10711,11 @@ class CustomStreamWrapper: except Exception as e: model_response.choices[0].delta = Delta() else: + if ( + self.stream_options is not None + and self.stream_options["include_usage"] == True + ): + return model_response return print_verbose( f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}" From e7e54772ae518258bfb52b5d9b7f612b60aa4750 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 8 May 2024 21:57:25 -0700 Subject: [PATCH 05/34] docs include `stream_options` param --- docs/my-website/docs/completion/input.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md index 11ca13121..451deaac4 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -83,6 +83,7 @@ def completion( top_p: Optional[float] = None, n: Optional[int] = None, stream: Optional[bool] = None, + stream_options: Optional[dict] = None, stop=None, max_tokens: Optional[int] = None, presence_penalty: Optional[float] = None, @@ -139,6 +140,10 @@ def completion( - `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message. +- `stream_options` *dict or null (optional)* - Options for streaming response. Only set this when you set `stream: true` + + - `include_usage` *boolean (optional)* - If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value. + - `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens. - `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion. From 8015bc1c4748ee997f5088fa02aa5806523f4a34 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Thu, 9 May 2024 07:44:15 -0700 Subject: [PATCH 06/34] Revert "Add support for async streaming to watsonx provider " --- litellm/llms/prompt_templates/factory.py | 30 ++- litellm/llms/watsonx.py | 296 +++++++---------------- litellm/main.py | 12 +- litellm/utils.py | 13 - 4 files changed, 112 insertions(+), 239 deletions(-) diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 3ae2db701..24a076dd0 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -487,7 +487,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template): def ibm_granite_pt(messages: list): """ - IBM's Granite chat models uses the template: + IBM's Granite models uses the template: <|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message} See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models @@ -503,13 +503,12 @@ def ibm_granite_pt(messages: list): "pre_message": "<|user|>\n", "post_message": "\n", }, - 'assistant': { - 'pre_message': '<|assistant|>\n', - 'post_message': '\n', + "assistant": { + "pre_message": "<|assistant|>\n", + "post_message": "\n", }, }, - final_prompt_value='<|assistant|>\n', - ) + ).strip() ### ANTHROPIC ### @@ -1525,9 +1524,24 @@ def prompt_factory( return mistral_instruct_pt(messages=messages) elif "meta-llama/llama-3" in model and "instruct" in model: # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/ - return hf_chat_template( - model="meta-llama/Meta-Llama-3-8B-Instruct", + return custom_prompt( + role_dict={ + "system": { + "pre_message": "<|start_header_id|>system<|end_header_id|>\n", + "post_message": "<|eot_id|>", + }, + "user": { + "pre_message": "<|start_header_id|>user<|end_header_id|>\n", + "post_message": "<|eot_id|>", + }, + "assistant": { + "pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", + "post_message": "<|eot_id|>", + }, + }, messages=messages, + initial_prompt_value="<|begin_of_text|>", + final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n", ) try: if "meta-llama/llama-2" in model and "chat" in model: diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py index 082cdb325..99f2d18ba 100644 --- a/litellm/llms/watsonx.py +++ b/litellm/llms/watsonx.py @@ -1,13 +1,12 @@ from enum import Enum import json, types, time # noqa: E401 -from contextlib import asynccontextmanager, contextmanager -from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List +from contextlib import contextmanager +from typing import Callable, Dict, Optional, Any, Union, List import httpx # type: ignore import requests # type: ignore import litellm -from litellm.utils import Logging, ModelResponse, Usage, get_secret -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from litellm.utils import ModelResponse, get_secret, Usage from .base import BaseLLM from .prompt_templates import factory as ptf @@ -193,7 +192,7 @@ class WatsonXAIEndpoint(str, Enum): class IBMWatsonXAI(BaseLLM): """ - Class to interface with IBM watsonx.ai API for text generation and embeddings. + Class to interface with IBM Watsonx.ai API for text generation and embeddings. Reference: https://cloud.ibm.com/apidocs/watsonx-ai """ @@ -344,7 +343,7 @@ class IBMWatsonXAI(BaseLLM): ) if token is None and api_key is not None: # generate the auth token - if print_verbose is not None: + if print_verbose: print_verbose("Generating IAM token for Watsonx.ai") token = self.generate_iam_token(api_key) elif token is None and api_key is None: @@ -378,9 +377,8 @@ class IBMWatsonXAI(BaseLLM): model_response: ModelResponse, print_verbose: Callable, encoding, - logging_obj: Logging, - optional_params: Optional[dict] = None, - acompletion: bool = None, + logging_obj, + optional_params: dict, litellm_params: Optional[dict] = None, logger_fn=None, timeout: Optional[float] = None, @@ -403,15 +401,13 @@ class IBMWatsonXAI(BaseLLM): prompt = convert_messages_to_prompt( model, messages, provider, custom_prompt_dict ) - - manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj) - def process_text_gen_response(json_resp: dict) -> ModelResponse: - if "results" not in json_resp: - raise WatsonXAIError( - status_code=500, - message=f"Error: Invalid response from Watsonx.ai API: {json_resp}", - ) + def process_text_request(request_params: dict) -> ModelResponse: + with self._manage_response( + request_params, logging_obj=logging_obj, input=prompt, timeout=timeout + ) as resp: + json_resp = resp.json() + generated_text = json_resp["results"][0]["generated_text"] prompt_tokens = json_resp["results"][0]["input_token_count"] completion_tokens = json_resp["results"][0]["generated_token_count"] @@ -430,52 +426,25 @@ class IBMWatsonXAI(BaseLLM): ) return model_response - def handle_text_request(request_params: dict) -> ModelResponse: - with manage_response( - request_params, input=prompt, timeout=timeout, - ) as resp: - json_resp = resp.json() - - return process_text_gen_response(json_resp) - - async def handle_text_request_async(request_params: dict) -> ModelResponse: - async with manage_response( - request_params, input=prompt, timeout=timeout, - ) as resp: - json_resp = resp.json() - return process_text_gen_response(json_resp) - - def handle_stream_request( + def process_stream_request( request_params: dict, ) -> litellm.CustomStreamWrapper: # stream the response - generated chunks will be handled # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream - with manage_response( - request_params, stream=True, input=prompt, timeout=timeout, + with self._manage_response( + request_params, + logging_obj=logging_obj, + stream=True, + input=prompt, + timeout=timeout, ) as resp: - streamwrapper = litellm.CustomStreamWrapper( + response = litellm.CustomStreamWrapper( resp.iter_lines(), model=model, custom_llm_provider="watsonx", logging_obj=logging_obj, ) - return streamwrapper - - async def handle_stream_request_async( - request_params: dict, - ) -> litellm.CustomStreamWrapper: - # stream the response - generated chunks will be handled - # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream - async with manage_response( - request_params, stream=True, input=prompt, timeout=timeout, - ) as resp: - streamwrapper = litellm.CustomStreamWrapper( - resp.aiter_lines(), - model=model, - custom_llm_provider="watsonx", - logging_obj=logging_obj, - ) - return streamwrapper + return response try: ## Get the response from the model @@ -486,18 +455,10 @@ class IBMWatsonXAI(BaseLLM): optional_params=optional_params, print_verbose=print_verbose, ) - if stream and acompletion: - # stream and async text generation - return handle_stream_request_async(req_params) - elif stream: - # streaming text generation - return handle_stream_request(req_params) - elif acompletion: - # async text generation - return handle_text_request_async(req_params) + if stream: + return process_stream_request(req_params) else: - # regular text generation - return handle_text_request(req_params) + return process_text_request(req_params) except WatsonXAIError as e: raise e except Exception as e: @@ -512,7 +473,6 @@ class IBMWatsonXAI(BaseLLM): model_response=None, optional_params=None, encoding=None, - aembedding=None, ): """ Send a text embedding request to the IBM Watsonx.ai API. @@ -547,6 +507,9 @@ class IBMWatsonXAI(BaseLLM): } request_params = dict(version=api_params["api_version"]) url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS + # request = httpx.Request( + # "POST", url, headers=headers, json=payload, params=request_params + # ) req_params = { "method": "POST", "url": url, @@ -554,47 +517,25 @@ class IBMWatsonXAI(BaseLLM): "json": payload, "params": request_params, } - manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj) - - def process_embedding_response(json_resp: dict) -> ModelResponse: - results = json_resp.get("results", []) - embedding_response = [] - for idx, result in enumerate(results): - embedding_response.append( - {"object": "embedding", "index": idx, "embedding": result["embedding"]} - ) - model_response["object"] = "list" - model_response["data"] = embedding_response - model_response["model"] = model - input_tokens = json_resp.get("input_token_count", 0) - model_response.usage = Usage( - prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + with self._manage_response( + req_params, logging_obj=logging_obj, input=input + ) as resp: + json_resp = resp.json() + + results = json_resp.get("results", []) + embedding_response = [] + for idx, result in enumerate(results): + embedding_response.append( + {"object": "embedding", "index": idx, "embedding": result["embedding"]} ) - return model_response - - def handle_embedding_request(request_params: dict) -> ModelResponse: - with manage_response( - request_params, input=input - ) as resp: - json_resp = resp.json() - return process_embedding_response(json_resp) - - async def handle_embedding_request_async(request_params: dict) -> ModelResponse: - async with manage_response( - request_params, input=input - ) as resp: - json_resp = resp.json() - return process_embedding_response(json_resp) - - try: - if aembedding: - return handle_embedding_request_async(req_params) - else: - return handle_embedding_request(req_params) - except WatsonXAIError as e: - raise e - except Exception as e: - raise WatsonXAIError(status_code=500, message=str(e)) + model_response["object"] = "list" + model_response["data"] = embedding_response + model_response["model"] = model + input_tokens = json_resp.get("input_token_count", 0) + model_response.usage = Usage( + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + ) + return model_response def generate_iam_token(self, api_key=None, **params): headers = {} @@ -616,116 +557,53 @@ class IBMWatsonXAI(BaseLLM): iam_access_token = json_data["access_token"] self.token = iam_access_token return iam_access_token - - def _make_response_manager( - self, - async_: bool, - logging_obj: Logging - ) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]: - """ - Returns a context manager that manages the response from the request. - if async_ is True, returns an async context manager, otherwise returns a regular context manager. - Usage: - ```python - manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj) - async with manage_response(request_params) as resp: - ... - # or - manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj) - with manage_response(request_params) as resp: - ... - ``` - """ - - def pre_call( - request_params: dict, - input:Optional[Any]=None, - ): - request_str = ( - f"response = {'await ' if async_ else ''}{request_params['method']}(\n" - f"\turl={request_params['url']},\n" - f"\tjson={request_params['json']},\n" - f")" - ) - logging_obj.pre_call( - input=input, - api_key=request_params["headers"].get("Authorization"), - additional_args={ - "complete_input_dict": request_params["json"], - "request_str": request_str, - }, - ) - - def post_call(resp, request_params): + @contextmanager + def _manage_response( + self, + request_params: dict, + logging_obj: Any, + stream: bool = False, + input: Optional[Any] = None, + timeout: Optional[float] = None, + ): + request_str = ( + f"response = {request_params['method']}(\n" + f"\turl={request_params['url']},\n" + f"\tjson={request_params['json']},\n" + f")" + ) + logging_obj.pre_call( + input=input, + api_key=request_params["headers"].get("Authorization"), + additional_args={ + "complete_input_dict": request_params["json"], + "request_str": request_str, + }, + ) + if timeout: + request_params["timeout"] = timeout + try: + if stream: + resp = requests.request( + **request_params, + stream=True, + ) + resp.raise_for_status() + yield resp + else: + resp = requests.request(**request_params) + resp.raise_for_status() + yield resp + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: logging_obj.post_call( input=input, api_key=request_params["headers"].get("Authorization"), original_response=json.dumps(resp.json()), additional_args={ "status_code": resp.status_code, - "complete_input_dict": request_params.get("data", request_params.get("json")), + "complete_input_dict": request_params["json"], }, ) - - @contextmanager - def _manage_response( - request_params: dict, - stream: bool = False, - input: Optional[Any] = None, - timeout: float = None, - ) -> Generator[requests.Response, None, None]: - """ - Returns a context manager that yields the response from the request. - """ - pre_call(request_params, input) - if timeout: - request_params["timeout"] = timeout - if stream: - request_params["stream"] = stream - try: - resp = requests.request(**request_params) - resp.raise_for_status() - yield resp - except Exception as e: - raise WatsonXAIError(status_code=500, message=str(e)) - if not stream: - post_call(resp, request_params) - - - @asynccontextmanager - async def _manage_response_async( - request_params: dict, - stream: bool = False, - input: Optional[Any] = None, - timeout: float = None, - ) -> AsyncGenerator[httpx.Response, None]: - pre_call(request_params, input) - if timeout: - request_params["timeout"] = timeout - if stream: - request_params["stream"] = stream - try: - # async with AsyncHTTPHandler(timeout=timeout) as client: - self.async_handler = AsyncHTTPHandler( - timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0), - ) - # async_handler.client.verify = False - if "json" in request_params: - request_params['data'] = json.dumps(request_params.pop("json", {})) - method = request_params.pop("method") - if method.upper() == "POST": - resp = await self.async_handler.post(**request_params) - else: - resp = await self.async_handler.get(**request_params) - yield resp - # await async_handler.close() - except Exception as e: - raise WatsonXAIError(status_code=500, message=str(e)) - if not stream: - post_call(resp, request_params) - - if async_: - return _manage_response_async - else: - return _manage_response diff --git a/litellm/main.py b/litellm/main.py index 1298dec4c..99e5ec224 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -73,7 +73,6 @@ from .llms.azure_text import AzureTextCompletion from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface -from .llms.watsonx import IBMWatsonXAI from .llms.prompt_templates.factory import ( prompt_factory, custom_prompt, @@ -110,7 +109,6 @@ anthropic_text_completions = AnthropicTextCompletion() azure_chat_completions = AzureChatCompletion() azure_text_completions = AzureTextCompletion() huggingface = Huggingface() -watsonxai = IBMWatsonXAI() ####### COMPLETION ENDPOINTS ################ @@ -315,7 +313,6 @@ async def acompletion( or custom_llm_provider == "gemini" or custom_llm_provider == "sagemaker" or custom_llm_provider == "anthropic" - or custom_llm_provider == "watsonx" or custom_llm_provider in litellm.openai_compatible_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. init_response = await loop.run_in_executor(None, func_with_context) @@ -1911,7 +1908,7 @@ def completion( response = response elif custom_llm_provider == "watsonx": custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - response = watsonxai.completion( + response = watsonx.IBMWatsonXAI().completion( model=model, messages=messages, custom_prompt_dict=custom_prompt_dict, @@ -1922,8 +1919,7 @@ def completion( logger_fn=logger_fn, encoding=encoding, logging_obj=logging, - acompletion=acompletion, - timeout=timeout, + timeout=timeout, # type: ignore ) if ( "stream" in optional_params @@ -2576,7 +2572,6 @@ async def aembedding(*args, **kwargs): or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "ollama" or custom_llm_provider == "vertex_ai" - or custom_llm_provider == "watsonx" ): # currently implemented aiohttp calls for just azure and openai, soon all. # Await normally init_response = await loop.run_in_executor(None, func_with_context) @@ -3034,14 +3029,13 @@ def embedding( aembedding=aembedding, ) elif custom_llm_provider == "watsonx": - response = watsonxai.embedding( + response = watsonx.IBMWatsonXAI().embedding( model=model, input=input, encoding=encoding, logging_obj=logging, optional_params=optional_params, model_response=EmbeddingResponse(), - aembedding=aembedding, ) else: args = locals() diff --git a/litellm/utils.py b/litellm/utils.py index d1af1b44a..c03d4e2bc 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -10567,18 +10567,6 @@ class CustomStreamWrapper: elif self.custom_llm_provider == "watsonx": response_obj = self.handle_watsonx_stream(chunk) completion_obj["content"] = response_obj["text"] - print_verbose(f"completion obj content: {completion_obj['content']}") - if getattr(model_response, "usage", None) is None: - model_response.usage = Usage() - if response_obj.get("prompt_tokens") is not None: - prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0) - model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"]) - if response_obj.get("completion_tokens") is not None: - model_response.usage.completion_tokens = response_obj["completion_tokens"] - model_response.usage.total_tokens = ( - getattr(model_response.usage, "prompt_tokens", 0) - + getattr(model_response.usage, "completion_tokens", 0) - ) if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "text-completion-openai": @@ -10983,7 +10971,6 @@ class CustomStreamWrapper: or self.custom_llm_provider == "sagemaker" or self.custom_llm_provider == "gemini" or self.custom_llm_provider == "cached_response" - or self.custom_llm_provider == "watsonx" or self.custom_llm_provider in litellm.openai_compatible_endpoints ): async for chunk in self.completion_stream: From dfd6361310bfc30b370a6e7a13699ae481e04403 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 07:59:37 -0700 Subject: [PATCH 07/34] fix completion vs acompletion params --- litellm/main.py | 1 + litellm/tests/test_acompletion.py | 23 ++++++++++++++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index d6d276653..186b87060 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -207,6 +207,7 @@ async def acompletion( api_version: Optional[str] = None, api_key: Optional[str] = None, model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + extra_headers: Optional[dict] = None, # Optional liteLLM function params **kwargs, ): diff --git a/litellm/tests/test_acompletion.py b/litellm/tests/test_acompletion.py index e5c09b9b7..b83e34653 100644 --- a/litellm/tests/test_acompletion.py +++ b/litellm/tests/test_acompletion.py @@ -1,5 +1,6 @@ import pytest from litellm import acompletion +from litellm import completion def test_acompletion_params(): @@ -7,17 +8,29 @@ def test_acompletion_params(): from litellm.types.completion import CompletionRequest acompletion_params_odict = inspect.signature(acompletion).parameters - acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()} - completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()} + completion_params_dict = inspect.signature(completion).parameters - # remove kwargs - acompletion_params.pop("kwargs", None) + acompletion_params = { + name: param.annotation for name, param in acompletion_params_odict.items() + } + completion_params = { + name: param.annotation for name, param in completion_params_dict.items() + } keys_acompletion = set(acompletion_params.keys()) keys_completion = set(completion_params.keys()) + print(keys_acompletion) + print("\n\n\n") + print(keys_completion) + + print("diff=", keys_completion - keys_acompletion) + # Assert that the parameters are the same if keys_acompletion != keys_completion: - pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.") + pytest.fail( + "The parameters of the litellm.acompletion function and litellm.completion are not the same." + ) + # test_acompletion_params() From 4cfd9885295f48854425b20584673b35ce9cfbfe Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 08:01:17 -0700 Subject: [PATCH 08/34] fix(get_api_base): fix get_api_base to handle model with alias --- litellm/tests/test_model_alias_map.py | 10 ++++++++-- litellm/utils.py | 2 ++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/litellm/tests/test_model_alias_map.py b/litellm/tests/test_model_alias_map.py index 1501f49e4..31a7d34b8 100644 --- a/litellm/tests/test_model_alias_map.py +++ b/litellm/tests/test_model_alias_map.py @@ -16,7 +16,7 @@ litellm.set_verbose = True model_alias_map = {"good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf"} -def test_model_alias_map(): +def test_model_alias_map(caplog): try: litellm.model_alias_map = model_alias_map response = completion( @@ -27,9 +27,15 @@ def test_model_alias_map(): max_tokens=10, ) print(response.model) + + captured_logs = [rec.levelname for rec in caplog.records] + + for log in captured_logs: + assert "ERROR" not in log + assert "Llama-2-7b-chat-hf" in response.model except Exception as e: pytest.fail(f"Error occurred: {e}") -test_model_alias_map() +# test_model_alias_map() diff --git a/litellm/utils.py b/litellm/utils.py index c03d4e2bc..5725e4992 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5934,6 +5934,8 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]: if _optional_params.api_base is not None: return _optional_params.api_base + if litellm.model_alias_map and model in litellm.model_alias_map: + model = litellm.model_alias_map[model] try: model, custom_llm_provider, dynamic_api_key, dynamic_api_base = ( get_llm_provider( From 4d5b4a5293ff56473f44c856acaa98af171abd62 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 08:35:35 -0700 Subject: [PATCH 09/34] add stream_options to text_completion --- litellm/main.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/litellm/main.py b/litellm/main.py index 186b87060..60e71a412 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3195,6 +3195,7 @@ def text_completion( Union[str, List[str]] ] = None, # Optional: Sequences where the API will stop generating further tokens. stream: Optional[bool] = None, # Optional: Whether to stream back partial progress. + stream_options: Optional[dict] = None, suffix: Optional[ str ] = None, # Optional: The suffix that comes after a completion of inserted text. @@ -3272,6 +3273,8 @@ def text_completion( optional_params["stop"] = stop if stream is not None: optional_params["stream"] = stream + if stream_options is not None: + optional_params["stream_options"] = stream_options if suffix is not None: optional_params["suffix"] = suffix if temperature is not None: @@ -3382,7 +3385,9 @@ def text_completion( if kwargs.get("acompletion", False) == True: return response if stream == True or kwargs.get("stream", False) == True: - response = TextCompletionStreamWrapper(completion_stream=response, model=model) + response = TextCompletionStreamWrapper( + completion_stream=response, model=model, stream_options=stream_options + ) return response transformed_logprobs = None # only supported for TGI models From 66053f14ae9d67ceb49b96055b0f8dd8831f9f9a Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 08:37:40 -0700 Subject: [PATCH 10/34] stream_options for text-completionopenai --- litellm/llms/openai.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index d542cbe07..674cc86a2 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -1205,6 +1205,7 @@ class OpenAITextCompletion(BaseLLM): model=model, custom_llm_provider="text-completion-openai", logging_obj=logging_obj, + stream_options=data.get("stream_options", None), ) for chunk in streamwrapper: @@ -1243,6 +1244,7 @@ class OpenAITextCompletion(BaseLLM): model=model, custom_llm_provider="text-completion-openai", logging_obj=logging_obj, + stream_options=data.get("stream_options", None), ) async for transformed_chunk in streamwrapper: From a29fcc057b5f00a7eb662a12e4e60d9560aba2a4 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 08:41:31 -0700 Subject: [PATCH 11/34] test - stream_options on OpenAI text_completion --- litellm/tests/test_streaming.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 7d639d7a3..93d7567eb 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1534,6 +1534,39 @@ def test_openai_stream_options_call(): assert all(chunk.usage is None for chunk in chunks[:-1]) +def test_openai_stream_options_call_text_completion(): + litellm.set_verbose = False + response = litellm.text_completion( + model="gpt-3.5-turbo-instruct", + prompt="say GM - we're going to make it ", + stream=True, + stream_options={"include_usage": True}, + max_tokens=10, + ) + usage = None + chunks = [] + for chunk in response: + print("chunk: ", chunk) + chunks.append(chunk) + + last_chunk = chunks[-1] + print("last chunk: ", last_chunk) + + """ + Assert that: + - Last Chunk includes Usage + - All chunks prior to last chunk have usage=None + """ + + assert last_chunk.usage is not None + assert last_chunk.usage.total_tokens > 0 + assert last_chunk.usage.prompt_tokens > 0 + assert last_chunk.usage.completion_tokens > 0 + + # assert all non last chunks have usage=None + assert all(chunk.usage is None for chunk in chunks[:-1]) + + def test_openai_text_completion_call(): try: litellm.set_verbose = True From e0b1eff1eb67d57a56be0bfa87cf7b636c7c57a4 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 08:42:25 -0700 Subject: [PATCH 12/34] feat - support stream_options for text completion --- litellm/utils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index 64a644f15..1932b8af6 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -10026,16 +10026,19 @@ class CustomStreamWrapper: text = "" is_finished = False finish_reason = None + usage = None choices = getattr(chunk, "choices", []) if len(choices) > 0: text = choices[0].text if choices[0].finish_reason is not None: is_finished = True finish_reason = choices[0].finish_reason + usage = getattr(chunk, "usage", None) return { "text": text, "is_finished": is_finished, "finish_reason": finish_reason, + "usage": usage, } except Exception as e: @@ -10565,6 +10568,11 @@ class CustomStreamWrapper: print_verbose(f"completion obj content: {completion_obj['content']}") if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] + if ( + self.stream_options + and self.stream_options.get("include_usage", False) == True + ): + model_response.usage = response_obj["usage"] elif self.custom_llm_provider == "azure_text": response_obj = self.handle_azure_text_completion_chunk(chunk) completion_obj["content"] = response_obj["text"] @@ -11094,9 +11102,10 @@ class CustomStreamWrapper: class TextCompletionStreamWrapper: - def __init__(self, completion_stream, model): + def __init__(self, completion_stream, model, stream_options): self.completion_stream = completion_stream self.model = model + self.stream_options = stream_options def __iter__(self): return self @@ -11120,6 +11129,14 @@ class TextCompletionStreamWrapper: text_choices["index"] = chunk["choices"][0]["index"] text_choices["finish_reason"] = chunk["choices"][0]["finish_reason"] response["choices"] = [text_choices] + + # only pass usage when stream_options["include_usage"] is True + if ( + self.stream_options + and self.stream_options.get("include_usage", False) == True + ): + response["usage"] = chunk.get("usage", None) + return response except Exception as e: raise Exception( From 6634ea37e9f8b9a20d2fef7cbc3ae0a7b55d438f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 09:54:44 -0700 Subject: [PATCH 13/34] fix TextCompletionStreamWrapper --- litellm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index 1932b8af6..09f1bb2e2 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -11102,7 +11102,7 @@ class CustomStreamWrapper: class TextCompletionStreamWrapper: - def __init__(self, completion_stream, model, stream_options): + def __init__(self, completion_stream, model, stream_options: Optional[dict] = None): self.completion_stream = completion_stream self.model = model self.stream_options = stream_options From 84b2af8e6c91719c36ebf2601af6193b75e51a68 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 12:27:44 -0700 Subject: [PATCH 14/34] fix show docker run on repo --- .github/workflows/interpret_load_test.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/workflows/interpret_load_test.py b/.github/workflows/interpret_load_test.py index 9d95c768f..de52f47ad 100644 --- a/.github/workflows/interpret_load_test.py +++ b/.github/workflows/interpret_load_test.py @@ -64,6 +64,11 @@ if __name__ == "__main__": ) # Replace with your repository's username and name latest_release = repo.get_latest_release() print("got latest release: ", latest_release) + print(latest_release.title) + print(latest_release.tag_name) + + release_version = latest_release.title + print("latest release body: ", latest_release.body) print("markdown table: ", markdown_table) @@ -74,6 +79,18 @@ if __name__ == "__main__": start_index = latest_release.body.find("Load Test LiteLLM Proxy Results") existing_release_body = latest_release.body[:start_index] + docker_run_command = f""" + ## Docker Run LiteLLM Proxy + + ``` + docker run \\ + -e STORE_MODEL_IN_DB=True \\ + -p 4000:4000 \\ + ghcr.io/berriai/litellm:main-{release_version} + ``` + """ + print("docker run command: ", docker_run_command) + new_release_body = ( existing_release_body + "\n\n" From 30b7e9f776475b2d84df9e5deb12d6ebaa44f349 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 12:28:14 -0700 Subject: [PATCH 15/34] temp fix laod test --- .github/workflows/load_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/load_test.yml b/.github/workflows/load_test.yml index ddf613fa6..9af700d24 100644 --- a/.github/workflows/load_test.yml +++ b/.github/workflows/load_test.yml @@ -30,7 +30,7 @@ jobs: URL: "https://litellm-database-docker-build-production.up.railway.app/" USERS: "100" RATE: "10" - RUNTIME: "300s" + RUNTIME: "3s" - name: Process Load Test Stats run: | echo "Current working directory: $PWD" From 50b4167a27822069a22c94d28e1b73da1b5bf44a Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 12:30:35 -0700 Subject: [PATCH 16/34] fix interpret load test --- .github/workflows/interpret_load_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/interpret_load_test.py b/.github/workflows/interpret_load_test.py index de52f47ad..b61f1ce67 100644 --- a/.github/workflows/interpret_load_test.py +++ b/.github/workflows/interpret_load_test.py @@ -80,6 +80,7 @@ if __name__ == "__main__": existing_release_body = latest_release.body[:start_index] docker_run_command = f""" + \n\n ## Docker Run LiteLLM Proxy ``` @@ -93,6 +94,7 @@ if __name__ == "__main__": new_release_body = ( existing_release_body + + docker_run_command + "\n\n" + "### Don't want to maintain your internal proxy? get in touch 🎉" + "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat" From c8662234c7b6c3d26b696d9284ec83dea1dfbe87 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 12:38:38 -0700 Subject: [PATCH 17/34] fix docker run command on release notes --- .github/workflows/interpret_load_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/interpret_load_test.py b/.github/workflows/interpret_load_test.py index b61f1ce67..b1a28e069 100644 --- a/.github/workflows/interpret_load_test.py +++ b/.github/workflows/interpret_load_test.py @@ -80,15 +80,15 @@ if __name__ == "__main__": existing_release_body = latest_release.body[:start_index] docker_run_command = f""" - \n\n - ## Docker Run LiteLLM Proxy +\n\n +## Docker Run LiteLLM Proxy - ``` - docker run \\ - -e STORE_MODEL_IN_DB=True \\ - -p 4000:4000 \\ - ghcr.io/berriai/litellm:main-{release_version} - ``` +``` +docker run \\ +-e STORE_MODEL_IN_DB=True \\ +-p 4000:4000 \\ +ghcr.io/berriai/litellm:main-{release_version} +``` """ print("docker run command: ", docker_run_command) From bf99311f5c3dcae825dee6763af90f40a9b18c24 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 9 May 2024 12:50:24 -0700 Subject: [PATCH 18/34] fix load test length --- .github/workflows/load_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/load_test.yml b/.github/workflows/load_test.yml index 9af700d24..ddf613fa6 100644 --- a/.github/workflows/load_test.yml +++ b/.github/workflows/load_test.yml @@ -30,7 +30,7 @@ jobs: URL: "https://litellm-database-docker-build-production.up.railway.app/" USERS: "100" RATE: "10" - RUNTIME: "3s" + RUNTIME: "300s" - name: Process Load Test Stats run: | echo "Current working directory: $PWD" From e3f25a4a1fbaa5de962845286863009e1ab00ced Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 13:05:41 -0700 Subject: [PATCH 19/34] fix(auth_checks.py): fix 'get_end_user_object' await cache get --- litellm/proxy/auth/auth_checks.py | 6 ++---- litellm/proxy/proxy_server.py | 7 +++++-- litellm/tests/test_key_generate_prisma.py | 9 ++++++++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 920de3cc8..62e5eba01 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -206,11 +206,9 @@ async def get_end_user_object( if end_user_id is None: return None - + _key = "end_user_id:{}".format(end_user_id) # check if in cache - cached_user_obj = user_api_key_cache.async_get_cache( - key="end_user_id:{}".format(end_user_id) - ) + cached_user_obj = await user_api_key_cache.async_get_cache(key=_key) if cached_user_obj is not None: if isinstance(cached_user_obj, dict): return LiteLLM_EndUserTable(**cached_user_obj) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f50fddf7f..1ba29dcc6 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1834,6 +1834,9 @@ async def update_cache( ) async def _update_end_user_cache(): + if end_user_id is None or response_cost is None: + return + _id = "end_user_id:{}".format(end_user_id) try: # Fetch the existing cost for the given user @@ -1846,7 +1849,7 @@ async def update_cache( if litellm.max_end_user_budget is not None: max_end_user_budget = litellm.max_end_user_budget existing_spend_obj = LiteLLM_EndUserTable( - user_id=_id, + user_id=end_user_id, spend=0, blocked=False, litellm_budget_table=LiteLLM_BudgetTable( @@ -1874,7 +1877,7 @@ async def update_cache( existing_spend_obj.spend = new_spend user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json()) except Exception as e: - verbose_proxy_logger.debug( + verbose_proxy_logger.error( f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}" ) diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 08618c988..e6f2437e7 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -418,9 +418,16 @@ def test_call_with_user_over_budget(prisma_client): print(vars(e)) +def test_end_user_cache_write_unit_test(): + """ + assert end user object is being written to cache as expected + """ + pass + + def test_call_with_end_user_over_budget(prisma_client): # Test if a user passed to /chat/completions is tracked & fails when they cross their budget - # we only check this when litellm.max_user_budget is set + # we only check this when litellm.max_end_user_budget is set import random setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) From 927d36148f5c82ba3f85cd79f81cf4c211bd6d7b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 13:21:00 -0700 Subject: [PATCH 20/34] feat(proxy_server.py): expose new `/team/list` endpoint Closes https://github.com/BerriAI/litellm/issues/3523 --- litellm/proxy/proxy_server.py | 37 +++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 1ba29dcc6..0456881cd 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -7313,6 +7313,43 @@ async def unblock_team( return record +@router.get( + "/team/list", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +async def list_team( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [Admin-only] List all available teams + + ``` + curl --location --request GET 'http://0.0.0.0:4000/team/list' \ + --header 'Authorization: Bearer sk-1234' + ``` + """ + global prisma_client + + if user_api_key_dict.user_role != "proxy_admin": + raise HTTPException( + status_code=401, + detail={ + "error": "Admin-only endpoint. Your user role={}".format( + user_api_key_dict.user_role + ) + }, + ) + + if prisma_client is None: + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + response = await prisma_client.db.litellm_teamtable.find_many() + + return response + + #### ORGANIZATION MANAGEMENT #### From c4295e16672ca612b2bc665e5fef71953de09e51 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 13:54:24 -0700 Subject: [PATCH 21/34] test(test_least_busy_routing.py): avoid deployments with low rate limits --- litellm/tests/test_least_busy_routing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/litellm/tests/test_least_busy_routing.py b/litellm/tests/test_least_busy_routing.py index cb9d59e75..203a03c63 100644 --- a/litellm/tests/test_least_busy_routing.py +++ b/litellm/tests/test_least_busy_routing.py @@ -150,9 +150,9 @@ async def test_router_atext_completion_streaming(): { "model_name": "azure-model", "litellm_params": { - "model": "azure/gpt-35-turbo", - "api_key": "os.environ/AZURE_EUROPE_API_KEY", - "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com", + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", "rpm": 6, }, "model_info": {"id": 2}, @@ -160,9 +160,9 @@ async def test_router_atext_completion_streaming(): { "model_name": "azure-model", "litellm_params": { - "model": "azure/gpt-35-turbo", - "api_key": "os.environ/AZURE_CANADA_API_KEY", - "api_base": "https://my-endpoint-canada-berri992.openai.azure.com", + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", "rpm": 6, }, "model_info": {"id": 3}, @@ -193,7 +193,7 @@ async def test_router_atext_completion_streaming(): ## check if calls equally distributed cache_dict = router.cache.get_cache(key=cache_key) for k, v in cache_dict.items(): - assert v == 1 + assert v == 1, f"Failed. K={k} called v={v} times, cache_dict={cache_dict}" # asyncio.run(test_router_atext_completion_streaming()) From acb615957dad014c795cf360d58fbe3a9f4fa096 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 13:58:45 -0700 Subject: [PATCH 22/34] fix(utils.py): change error log to be debug --- litellm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index 20716f43d..6da296038 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5938,7 +5938,7 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]: model=model, **optional_params ) # convert to pydantic object except Exception as e: - verbose_logger.error("Error occurred in getting api base - {}".format(str(e))) + verbose_logger.debug("Error occurred in getting api base - {}".format(str(e))) return None # get llm provider From 5c6a382d3b50566e38769521405e89d7d8439d97 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 15:41:33 -0700 Subject: [PATCH 23/34] refactor(main.py): trigger new build --- litellm/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/main.py b/litellm/main.py index 99e556bfa..5ab3fd7c4 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -14,6 +14,7 @@ import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy import httpx import litellm + from ._logging import verbose_logger from litellm import ( # type: ignore client, From 43b2050cc25006f607307ce3efb4b0089ab98f8d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 15:41:40 -0700 Subject: [PATCH 24/34] =?UTF-8?q?bump:=20version=201.36.4=20=E2=86=92=201.?= =?UTF-8?q?37.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 399b188c2..a9854cf69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.36.4" +version = "1.37.0" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.36.4" +version = "1.37.0" version_files = [ "pyproject.toml:^version" ] From c42f1ce2c65375bde058b1f2b399bbd88ea09426 Mon Sep 17 00:00:00 2001 From: Nick Wong Date: Thu, 9 May 2024 16:13:26 -0700 Subject: [PATCH 25/34] removed extra default dict return, which causes error if user_role is a string --- litellm/proxy/proxy_server.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 4bb8dee7f..86c22186d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1065,9 +1065,7 @@ async def user_api_key_auth( user_id_information, list ): _user = user_id_information[0] - user_role = _user.get("user_role", {}).get( - "user_role", "unknown" - ) + user_role = _user.get("user_role", "unknown") user_id = _user.get("user_id", "unknown") raise Exception( f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}" From 186c0ec77bec2900a184d8628ff1457dd25b114e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 16:39:43 -0700 Subject: [PATCH 26/34] feat(predibase.py): add support for predibase provider Closes https://github.com/BerriAI/litellm/issues/1253 --- litellm/__init__.py | 4 + litellm/llms/huggingface_restapi.py | 2 +- litellm/llms/predibase.py | 417 ++++++++++++++++++++++++++++ litellm/main.py | 54 ++++ litellm/tests/test_completion.py | 23 ++ litellm/utils.py | 2 +- 6 files changed, 500 insertions(+), 2 deletions(-) create mode 100644 litellm/llms/predibase.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 4f72504f6..ccf3657fe 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -71,9 +71,11 @@ maritalk_key: Optional[str] = None ai21_key: Optional[str] = None ollama_key: Optional[str] = None openrouter_key: Optional[str] = None +predibase_key: Optional[str] = None huggingface_key: Optional[str] = None vertex_project: Optional[str] = None vertex_location: Optional[str] = None +predibase_tenant_id: Optional[str] = None togetherai_api_key: Optional[str] = None cloudflare_api_key: Optional[str] = None baseten_key: Optional[str] = None @@ -532,6 +534,7 @@ provider_list: List = [ "xinference", "fireworks_ai", "watsonx", + "predibase", "custom", # custom apis ] @@ -644,6 +647,7 @@ from .utils import ( ) from .llms.huggingface_restapi import HuggingfaceConfig from .llms.anthropic import AnthropicConfig +from .llms.predibase import PredibaseConfig from .llms.anthropic_text import AnthropicTextConfig from .llms.replicate import ReplicateConfig from .llms.cohere import CohereConfig diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py index 293773289..b250f3013 100644 --- a/litellm/llms/huggingface_restapi.py +++ b/litellm/llms/huggingface_restapi.py @@ -322,9 +322,9 @@ class Huggingface(BaseLLM): encoding, api_key, logging_obj, + optional_params: dict, custom_prompt_dict={}, acompletion: bool = False, - optional_params=None, litellm_params=None, logger_fn=None, ): diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py new file mode 100644 index 000000000..728a98b04 --- /dev/null +++ b/litellm/llms/predibase.py @@ -0,0 +1,417 @@ +# What is this? +## Controller file for Predibase Integration - https://predibase.com/ + + +import os, types +import json +from enum import Enum +import requests, copy # type: ignore +import time +from typing import Callable, Optional, List, Literal, Union +from litellm.utils import ( + ModelResponse, + Usage, + map_finish_reason, + CustomStreamWrapper, + Message, + Choices, +) +import litellm +from .prompt_templates.factory import prompt_factory, custom_prompt +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from .base import BaseLLM +import httpx # type: ignore + + +class PredibaseError(Exception): + def __init__( + self, + status_code, + message, + request: Optional[httpx.Request] = None, + response: Optional[httpx.Response] = None, + ): + self.status_code = status_code + self.message = message + if request is not None: + self.request = request + else: + self.request = httpx.Request( + method="POST", + url="https://docs.predibase.com/user-guide/inference/rest_api", + ) + if response is not None: + self.response = response + else: + self.response = httpx.Response( + status_code=status_code, request=self.request + ) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +class PredibaseConfig: + """ + Reference: https://docs.predibase.com/user-guide/inference/rest_api + + """ + + adapter_id: Optional[str] = None + adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None + best_of: Optional[int] = None + decoder_input_details: bool = True # on by default - get the finish reason + details: Optional[bool] = True # enables returning logprobs + best of + max_new_tokens: int = ( + 256 # openai default - requests hang if max_new_tokens not given + ) + repetition_penalty: Optional[float] = None + return_full_text: Optional[bool] = ( + False # by default don't return the input as part of the output + ) + seed: Optional[int] = None + stop: Optional[List[str]] = None + temperature: Optional[float] = None + top_k: Optional[int] = None + top_p: Optional[int] = None + truncate: Optional[int] = None + typical_p: Optional[float] = None + watermark: Optional[bool] = None + + def __init__( + self, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + details: Optional[bool] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[int] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self): + return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] + + +class PredibaseChatCompletion(BaseLLM): + def __init__(self) -> None: + super().__init__() + + def validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict: + if api_key is None: + raise ValueError( + "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params" + ) + headers = { + "content-type": "application/json", + "Authorization": "Bearer {}".format(api_key), + } + if user_headers is not None and isinstance(user_headers, dict): + headers = {**headers, **user_headers} + return headers + + def output_parser(self, generated_text: str): + """ + Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens. + + Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763 + """ + chat_template_tokens = [ + "<|assistant|>", + "<|system|>", + "<|user|>", + "", + "", + ] + for token in chat_template_tokens: + if generated_text.strip().startswith(token): + generated_text = generated_text.replace(token, "", 1) + if generated_text.endswith(token): + generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1] + return generated_text + + def process_response( + self, + model: str, + response: requests.Response, + model_response: ModelResponse, + stream: bool, + logging_obj: litellm.utils.Logging, + optional_params: dict, + api_key: str, + data: dict, + messages: list, + print_verbose, + encoding, + ) -> ModelResponse: + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT + try: + completion_response = response.json() + except: + raise PredibaseError( + message=response.text, status_code=response.status_code + ) + if "error" in completion_response: + raise PredibaseError( + message=str(completion_response["error"]), + status_code=response.status_code, + ) + else: + if ( + not isinstance(completion_response, dict) + or "generated_text" not in completion_response + ): + raise PredibaseError( + status_code=422, + message=f"response is not in expected format - {completion_response}", + ) + + if len(completion_response["generated_text"]) > 0: + model_response["choices"][0]["message"]["content"] = self.output_parser( + completion_response["generated_text"] + ) + ## GETTING LOGPROBS + FINISH REASON + if ( + "details" in completion_response + and "tokens" in completion_response["details"] + ): + model_response.choices[0].finish_reason = completion_response[ + "details" + ]["finish_reason"] + sum_logprob = 0 + for token in completion_response[0]["details"]["tokens"]: + if token["logprob"] != None: + sum_logprob += token["logprob"] + model_response["choices"][0][ + "message" + ]._logprob = ( + sum_logprob # [TODO] move this to using the actual logprobs + ) + if "best_of" in optional_params and optional_params["best_of"] > 1: + if ( + "details" in completion_response[0] + and "best_of_sequences" in completion_response[0]["details"] + ): + choices_list = [] + for idx, item in enumerate( + completion_response[0]["details"]["best_of_sequences"] + ): + sum_logprob = 0 + for token in item["tokens"]: + if token["logprob"] != None: + sum_logprob += token["logprob"] + if len(item["generated_text"]) > 0: + message_obj = Message( + content=self.output_parser(item["generated_text"]), + logprobs=sum_logprob, + ) + else: + message_obj = Message(content=None) + choice_obj = Choices( + finish_reason=item["finish_reason"], + index=idx + 1, + message=message_obj, + ) + choices_list.append(choice_obj) + model_response["choices"].extend(choices_list) + + ## CALCULATING USAGE + prompt_tokens = 0 + try: + prompt_tokens = len( + encoding.encode(model_response["choices"][0]["message"]["content"]) + ) ##[TODO] use a model-specific tokenizer here + except: + # this should remain non blocking we should not block a response returning if calculating usage fails + pass + output_text = model_response["choices"][0]["message"].get("content", "") + if output_text is not None and len(output_text) > 0: + completion_tokens = 0 + try: + completion_tokens = len( + encoding.encode( + model_response["choices"][0]["message"].get("content", "") + ) + ) ##[TODO] use a model-specific tokenizer + except: + # this should remain non blocking we should not block a response returning if calculating usage fails + pass + else: + completion_tokens = 0 + + total_tokens = prompt_tokens + completion_tokens + + model_response["created"] = int(time.time()) + model_response["model"] = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + model_response.usage = usage # type: ignore + return model_response + + def completion( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key: str, + logging_obj, + optional_params: dict, + tenant_id: str, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers: dict = {}, + ): + headers = self.validate_environment(api_key, headers) + completion_url = "" + input_text = "" + base_url = "https://serving.app.predibase.com" + if "https" in model: + completion_url = model + elif api_base: + base_url = api_base + elif "PREDIBASE_API_BASE" in os.environ: + base_url = os.getenv("PREDIBASE_API_BASE", "") + + completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}/generate" + + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) + else: + prompt = prompt_factory(model=model, messages=messages) + + ## Load Config + config = litellm.PredibaseConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + data = { + "inputs": prompt, + "parameters": optional_params, + } + if optional_params.get("stream") and optional_params["stream"] == True: + data["stream"] = True + input_text = prompt + ## LOGGING + logging_obj.pre_call( + input=input_text, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "headers": headers, + "api_base": completion_url, + "acompletion": acompletion, + }, + ) + ## COMPLETION CALL + if acompletion is True: + ### ASYNC STREAMING + if optional_params.get("stream", False): + return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout) # type: ignore + else: + ### ASYNC COMPLETION + return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore + + ### SYNC STREAMING + if "stream" in optional_params and optional_params["stream"] == True: + response = requests.post( + completion_url, + headers=headers, + data=json.dumps(data), + stream=optional_params["stream"], + ) + return response.iter_lines() + ### SYNC COMPLETION + else: + payload = json.dumps( + { + "inputs": "What is your name?", + "parameters": {"max_new_tokens": 20, "temperature": 0.1}, + } + # data + ) + response = requests.post( + url="https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate", + headers=headers, + data=payload, + ) + + return self.process_response( + model=model, + response=response, + model_response=model_response, + stream=optional_params.get("stream", False), + logging_obj=logging_obj, # type: ignore + optional_params=optional_params, + api_key=api_key, + data=data, + messages=messages, + print_verbose=print_verbose, + encoding=encoding, + ) + + async def async_completion(self): + pass + + async def async_streaming(self): + pass + + def streaming(self): + pass + + def embedding(self, *args, **kwargs): + pass diff --git a/litellm/main.py b/litellm/main.py index 5ab3fd7c4..f634fd16d 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -74,6 +74,7 @@ from .llms.azure_text import AzureTextCompletion from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface +from .llms.predibase import PredibaseChatCompletion from .llms.prompt_templates.factory import ( prompt_factory, custom_prompt, @@ -110,6 +111,7 @@ anthropic_text_completions = AnthropicTextCompletion() azure_chat_completions = AzureChatCompletion() azure_text_completions = AzureTextCompletion() huggingface = Huggingface() +predibase_chat_completions = PredibaseChatCompletion() ####### COMPLETION ENDPOINTS ################ @@ -1785,6 +1787,58 @@ def completion( ) return response response = model_response + elif custom_llm_provider == "predibase": + tenant_id = ( + optional_params.pop("tenant_id", None) + or optional_params.pop("predibase_tenant_id", None) + or litellm.predibase_tenant_id + or get_secret("PREDIBASE_TENANT_ID") + ) + + api_base = ( + optional_params.pop("api_base", None) + or optional_params.pop("base_url", None) + or litellm.api_base + or get_secret("PREDIBASE_API_BASE") + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.predibase_key + or get_secret("PREDIBASE_API_KEY") + ) + + model_response = predibase_chat_completions.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + acompletion=acompletion, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + api_key=api_key, + tenant_id=tenant_id, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion == False + ): + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="predibase", + logging_obj=logging, + ) + return response + response = model_response elif custom_llm_provider == "ai21": custom_llm_provider = "ai21" ai21_key = ( diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 32b65faea..7f0977b15 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -85,6 +85,29 @@ def test_completion_azure_command_r(): pytest.fail(f"Error occurred: {e}") +@pytest.mark.skip(reason="local test") +def test_completion_predibase(): + try: + litellm.set_verbose = True + + response = completion( + model="predibase/llama-3-8b-instruct", + tenant_id="c4768f95", + api_base="https://serving.app.predibase.com", + api_key=os.getenv("PREDIBASE_API_KEY"), + messages=[{"role": "user", "content": "What is the meaning of life?"}], + ) + + print(response) + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +# test_completion_predibase() + + def test_completion_claude(): litellm.set_verbose = True litellm.cache = None diff --git a/litellm/utils.py b/litellm/utils.py index 6da296038..7ccb5e8ff 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -369,7 +369,7 @@ class ChatCompletionMessageToolCall(OpenAIObject): class Message(OpenAIObject): def __init__( self, - content="default", + content: Optional[str] = "default", role="assistant", logprobs=None, function_call=None, From d7189c21fdd57e8cf445c7ef82c9d73df3dfaf7e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 17:41:27 -0700 Subject: [PATCH 27/34] feat(predibase.py): support async_completion + streaming (sync + async) finishes up pr --- litellm/llms/predibase.py | 166 +- litellm/main.py | 7 +- litellm/tests/log.txt | 6981 +++++++++++++++++++++++++++++- litellm/tests/test_completion.py | 33 +- litellm/tests/test_streaming.py | 67 +- litellm/utils.py | 50 + 6 files changed, 7196 insertions(+), 108 deletions(-) diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py index 728a98b04..f3935984d 100644 --- a/litellm/llms/predibase.py +++ b/litellm/llms/predibase.py @@ -60,8 +60,8 @@ class PredibaseConfig: adapter_id: Optional[str] = None adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None best_of: Optional[int] = None - decoder_input_details: bool = True # on by default - get the finish reason - details: Optional[bool] = True # enables returning logprobs + best of + decoder_input_details: Optional[bool] = None + details: bool = True # enables returning logprobs + best of max_new_tokens: int = ( 256 # openai default - requests hang if max_new_tokens not given ) @@ -124,6 +124,9 @@ class PredibaseConfig: class PredibaseChatCompletion(BaseLLM): def __init__(self) -> None: + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=litellm.request_timeout, connect=5.0) + ) super().__init__() def validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict: @@ -162,7 +165,7 @@ class PredibaseChatCompletion(BaseLLM): def process_response( self, model: str, - response: requests.Response, + response: Union[requests.Response, httpx.Response], model_response: ModelResponse, stream: bool, logging_obj: litellm.utils.Logging, @@ -216,7 +219,7 @@ class PredibaseChatCompletion(BaseLLM): "details" ]["finish_reason"] sum_logprob = 0 - for token in completion_response[0]["details"]["tokens"]: + for token in completion_response["details"]["tokens"]: if token["logprob"] != None: sum_logprob += token["logprob"] model_response["choices"][0][ @@ -226,12 +229,12 @@ class PredibaseChatCompletion(BaseLLM): ) if "best_of" in optional_params and optional_params["best_of"] > 1: if ( - "details" in completion_response[0] - and "best_of_sequences" in completion_response[0]["details"] + "details" in completion_response + and "best_of_sequences" in completion_response["details"] ): choices_list = [] for idx, item in enumerate( - completion_response[0]["details"]["best_of_sequences"] + completion_response["details"]["best_of_sequences"] ): sum_logprob = 0 for token in item["tokens"]: @@ -305,7 +308,7 @@ class PredibaseChatCompletion(BaseLLM): litellm_params=None, logger_fn=None, headers: dict = {}, - ): + ) -> Union[ModelResponse, CustomStreamWrapper]: headers = self.validate_environment(api_key, headers) completion_url = "" input_text = "" @@ -317,7 +320,12 @@ class PredibaseChatCompletion(BaseLLM): elif "PREDIBASE_API_BASE" in os.environ: base_url = os.getenv("PREDIBASE_API_BASE", "") - completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}/generate" + completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}" + + if optional_params.get("stream", False) == True: + completion_url += "/generate_stream" + else: + completion_url += "/generate" if model in custom_prompt_dict: # check if the model has a registered custom prompt @@ -339,12 +347,12 @@ class PredibaseChatCompletion(BaseLLM): ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v + stream = optional_params.pop("stream", False) + data = { "inputs": prompt, "parameters": optional_params, } - if optional_params.get("stream") and optional_params["stream"] == True: - data["stream"] = True input_text = prompt ## LOGGING logging_obj.pre_call( @@ -360,34 +368,62 @@ class PredibaseChatCompletion(BaseLLM): ## COMPLETION CALL if acompletion is True: ### ASYNC STREAMING - if optional_params.get("stream", False): - return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout) # type: ignore + if stream == True: + return self.async_streaming( + model=model, + messages=messages, + data=data, + api_base=completion_url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + ) # type: ignore else: ### ASYNC COMPLETION - return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore + return self.async_completion( + model=model, + messages=messages, + data=data, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + stream=False, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + ) # type: ignore ### SYNC STREAMING - if "stream" in optional_params and optional_params["stream"] == True: + if stream == True: response = requests.post( completion_url, headers=headers, data=json.dumps(data), - stream=optional_params["stream"], + stream=stream, ) - return response.iter_lines() + response = CustomStreamWrapper( + response.iter_lines(), + model, + custom_llm_provider="predibase", + logging_obj=logging_obj, + ) + return response ### SYNC COMPLETION else: - payload = json.dumps( - { - "inputs": "What is your name?", - "parameters": {"max_new_tokens": 20, "temperature": 0.1}, - } - # data - ) response = requests.post( - url="https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate", + url=completion_url, headers=headers, - data=payload, + data=json.dumps(data), ) return self.process_response( @@ -404,14 +440,80 @@ class PredibaseChatCompletion(BaseLLM): encoding=encoding, ) - async def async_completion(self): - pass + async def async_completion( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + stream, + data: dict, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + ) -> ModelResponse: - async def async_streaming(self): - pass + response = await self.async_handler.post( + api_base, headers=headers, data=json.dumps(data) + ) + return self.process_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + api_key=api_key, + data=data, + messages=messages, + print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) - def streaming(self): - pass + async def async_streaming( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + data: dict, + optional_params=None, + litellm_params=None, + logger_fn=None, + headers={}, + ) -> CustomStreamWrapper: + + data["stream"] = True + response = await self.async_handler.post( + url="https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate_stream", + headers=headers, + data=json.dumps(data), + stream=True, + ) + + if response.status_code != 200: + raise PredibaseError( + status_code=response.status_code, message=response.text + ) + + completion_stream = response.aiter_lines() + + streamwrapper = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="predibase", + logging_obj=logging_obj, + ) + return streamwrapper def embedding(self, *args, **kwargs): pass diff --git a/litellm/main.py b/litellm/main.py index f634fd16d..cec49d35f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -320,6 +320,7 @@ async def acompletion( or custom_llm_provider == "gemini" or custom_llm_provider == "sagemaker" or custom_llm_provider == "anthropic" + or custom_llm_provider == "predibase" or custom_llm_provider in litellm.openai_compatible_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. init_response = await loop.run_in_executor(None, func_with_context) @@ -1831,12 +1832,6 @@ def completion( and optional_params["stream"] == True and acompletion == False ): - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="predibase", - logging_obj=logging, - ) return response response = model_response elif custom_llm_provider == "ai21": diff --git a/litellm/tests/log.txt b/litellm/tests/log.txt index 2c76ac947..4d3027355 100644 --- a/litellm/tests/log.txt +++ b/litellm/tests/log.txt @@ -1,25 +1,6884 @@ ============================= test session starts ============================== -platform darwin -- Python 3.11.6, pytest-7.3.1, pluggy-1.3.0 +platform darwin -- Python 3.11.9, pytest-7.3.1, pluggy-1.3.0 rootdir: /Users/krrishdholakia/Documents/litellm/litellm/tests plugins: timeout-2.2.0, asyncio-0.23.2, anyio-3.7.1, xdist-3.3.1 asyncio: mode=Mode.STRICT -collected 1 item +collected 2 items -test_completion.py Chunks sorted -token_counter messages received: [{'role': 'user', 'content': 'what is the capital of congo?'}] -Token Counter - using generic token counter, for model=gemini-1.5-pro-latest -LiteLLM: Utils - Counting tokens for OpenAI model=gpt-3.5-turbo -.Token Counter - using generic token counter, for model=gemini-1.5-pro-latest -LiteLLM: Utils - Counting tokens for OpenAI model=gpt-3.5-turbo -Looking up model=gemini/gemini-1.5-pro-latest in model_cost_map -Success: model=gemini/gemini-1.5-pro-latest in model_cost_map -prompt_tokens=15; completion_tokens=1 -Returned custom cost for model=gemini/gemini-1.5-pro-latest - prompt_tokens_cost_usd_dollar: 0, completion_tokens_cost_usd_dollar: 0 -final cost: 0; prompt_tokens_cost_usd_dollar: 0; completion_tokens_cost_usd_dollar: 0 - [100%] +test_streaming.py .Token Counter - using hugging face token counter, for model=llama-3-8b-instruct +Looking up model=llama-3-8b-instruct in model_cost_map +F [100%] +=================================== FAILURES =================================== +__________________ test_completion_predibase_streaming[True] ___________________ + +model = 'llama-3-8b-instruct' +messages = [{'content': 'What is the meaning of life?', 'role': 'user'}] +timeout = 600.0, temperature = None, top_p = None, n = None, stream = True +stream_options = None, stop = None, max_tokens = None, presence_penalty = None +frequency_penalty = None, logit_bias = None, user = None, response_format = None +seed = None, tools = None, tool_choice = None, logprobs = None +top_logprobs = None, deployment_id = None, extra_headers = None +functions = None, function_call = None, base_url = None, api_version = None +api_key = 'pb_Qg9YbQo7UqqHdu0ozxN_aw', model_list = None +kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , 'tenant_id': 'c4768f95'} +args = {'acompletion': False, 'api_base': None, 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'api_version': None, ...} +api_base = None, mock_response = None, force_timeout = 600, logger_fn = None +verbose = False, custom_llm_provider = 'predibase' + + @client + def completion( + model: str, + # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create + messages: List = [], + timeout: Optional[Union[float, str, httpx.Timeout]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream: Optional[bool] = None, + stream_options: Optional[dict] = None, + stop=None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[dict] = None, + user: Optional[str] = None, + # openai v1.0+ new params + response_format: Optional[dict] = None, + seed: Optional[int] = None, + tools: Optional[List] = None, + tool_choice: Optional[str] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + deployment_id=None, + extra_headers: Optional[dict] = None, + # soon to be deprecated params by OpenAI + functions: Optional[List] = None, + function_call: Optional[str] = None, + # set api_base, api_version, api_key + base_url: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + # Optional liteLLM function params + **kwargs, + ) -> Union[ModelResponse, CustomStreamWrapper]: + """ + Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) + Parameters: + model (str): The name of the language model to use for text completion. see all supported LLMs: https://docs.litellm.ai/docs/providers/ + messages (List): A list of message objects representing the conversation context (default is an empty list). + + OPTIONAL PARAMS + functions (List, optional): A list of functions to apply to the conversation messages (default is an empty list). + function_call (str, optional): The name of the function to call within the conversation (default is an empty string). + temperature (float, optional): The temperature parameter for controlling the randomness of the output (default is 1.0). + top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). + n (int, optional): The number of completions to generate (default is 1). + stream (bool, optional): If True, return a streaming response (default is False). + stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true. + stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. + max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). + presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. + frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. + logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. + user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. + logprobs (bool, optional): Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message + top_logprobs (int, optional): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. + metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. + api_base (str, optional): Base URL for the API (default is None). + api_version (str, optional): API version (default is None). + api_key (str, optional): API key (default is None). + model_list (list, optional): List of api base, version, keys + extra_headers (dict, optional): Additional headers to include in the request. + + LITELLM Specific Params + mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None). + custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock" + max_retries (int, optional): The number of retries to attempt (default is 0). + Returns: + ModelResponse: A response object containing the generated completion and associated metadata. + + Note: + - This function is used to perform completions() using the specified language model. + - It supports various optional parameters for customizing the completion behavior. + - If 'mock_response' is provided, a mock completion response is returned for testing or debugging. + """ + ######### unpacking kwargs ##################### + args = locals() + api_base = kwargs.get("api_base", None) + mock_response = kwargs.get("mock_response", None) + force_timeout = kwargs.get("force_timeout", 600) ## deprecated + logger_fn = kwargs.get("logger_fn", None) + verbose = kwargs.get("verbose", False) + custom_llm_provider = kwargs.get("custom_llm_provider", None) + litellm_logging_obj = kwargs.get("litellm_logging_obj", None) + id = kwargs.get("id", None) + metadata = kwargs.get("metadata", None) + model_info = kwargs.get("model_info", None) + proxy_server_request = kwargs.get("proxy_server_request", None) + fallbacks = kwargs.get("fallbacks", None) + headers = kwargs.get("headers", None) + num_retries = kwargs.get("num_retries", None) ## deprecated + max_retries = kwargs.get("max_retries", None) + context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None) + organization = kwargs.get("organization", None) + ### CUSTOM MODEL COST ### + input_cost_per_token = kwargs.get("input_cost_per_token", None) + output_cost_per_token = kwargs.get("output_cost_per_token", None) + input_cost_per_second = kwargs.get("input_cost_per_second", None) + output_cost_per_second = kwargs.get("output_cost_per_second", None) + ### CUSTOM PROMPT TEMPLATE ### + initial_prompt_value = kwargs.get("initial_prompt_value", None) + roles = kwargs.get("roles", None) + final_prompt_value = kwargs.get("final_prompt_value", None) + bos_token = kwargs.get("bos_token", None) + eos_token = kwargs.get("eos_token", None) + preset_cache_key = kwargs.get("preset_cache_key", None) + hf_model_name = kwargs.get("hf_model_name", None) + supports_system_message = kwargs.get("supports_system_message", None) + ### TEXT COMPLETION CALLS ### + text_completion = kwargs.get("text_completion", False) + atext_completion = kwargs.get("atext_completion", False) + ### ASYNC CALLS ### + acompletion = kwargs.get("acompletion", False) + client = kwargs.get("client", None) + ### Admin Controls ### + no_log = kwargs.get("no-log", False) + ######## end of unpacking kwargs ########### + openai_params = [ + "functions", + "function_call", + "temperature", + "temperature", + "top_p", + "n", + "stream", + "stream_options", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + "logprobs", + "top_logprobs", + "extra_headers", + ] + litellm_params = [ + "metadata", + "acompletion", + "atext_completion", + "text_completion", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "retry_policy", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "max_parallel_requests", + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_second", + "output_cost_per_second", + "hf_model_name", + "model_info", + "proxy_server_request", + "preset_cache_key", + "caching_groups", + "ttl", + "cache", + "no-log", + "base_model", + "stream_timeout", + "supports_system_message", + "region_name", + "allowed_model_region", + ] + default_params = openai_params + litellm_params + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider + + ### TIMEOUT LOGIC ### + timeout = timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + + try: + if base_url is not None: + api_base = base_url + if max_retries is not None: # openai allows openai.OpenAI(max_retries=3) + num_retries = max_retries + logging = litellm_logging_obj + fallbacks = fallbacks or litellm.model_fallbacks + if fallbacks is not None: + return completion_with_fallbacks(**args) + if model_list is not None: + deployments = [ + m["litellm_params"] for m in model_list if m["model_name"] == model + ] + return batch_completion_models(deployments=deployments, **args) + if litellm.model_alias_map and model in litellm.model_alias_map: + model = litellm.model_alias_map[ + model + ] # update the model to the actual value if an alias has been passed in + model_response = ModelResponse() + setattr(model_response, "usage", litellm.Usage()) + if ( + kwargs.get("azure", False) == True + ): # don't remove flag check, to remain backwards compatible for repos like Codium + custom_llm_provider = "azure" + if deployment_id != None: # azure llms + model = deployment_id + custom_llm_provider = "azure" + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + ) + if model_response is not None and hasattr(model_response, "_hidden_params"): + model_response._hidden_params["custom_llm_provider"] = custom_llm_provider + model_response._hidden_params["region_name"] = kwargs.get( + "aws_region_name", None + ) # support region-based pricing for bedrock + + ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### + if input_cost_per_token is not None and output_cost_per_token is not None: + print_verbose(f"Registering model={model} in model cost map") + litellm.register_model( + { + f"{custom_llm_provider}/{model}": { + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "litellm_provider": custom_llm_provider, + }, + model: { + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "litellm_provider": custom_llm_provider, + }, + } + ) + elif ( + input_cost_per_second is not None + ): # time based pricing just needs cost in place + output_cost_per_second = output_cost_per_second + litellm.register_model( + { + f"{custom_llm_provider}/{model}": { + "input_cost_per_second": input_cost_per_second, + "output_cost_per_second": output_cost_per_second, + "litellm_provider": custom_llm_provider, + }, + model: { + "input_cost_per_second": input_cost_per_second, + "output_cost_per_second": output_cost_per_second, + "litellm_provider": custom_llm_provider, + }, + } + ) + ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### + custom_prompt_dict = {} # type: ignore + if ( + initial_prompt_value + or roles + or final_prompt_value + or bos_token + or eos_token + ): + custom_prompt_dict = {model: {}} + if initial_prompt_value: + custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value + if roles: + custom_prompt_dict[model]["roles"] = roles + if final_prompt_value: + custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value + if bos_token: + custom_prompt_dict[model]["bos_token"] = bos_token + if eos_token: + custom_prompt_dict[model]["eos_token"] = eos_token + + if ( + supports_system_message is not None + and isinstance(supports_system_message, bool) + and supports_system_message == False + ): + messages = map_system_message_pt(messages=messages) + model_api_key = get_api_key( + llm_provider=custom_llm_provider, dynamic_api_key=api_key + ) # get the api key from the environment if required for the model + + if dynamic_api_key is not None: + api_key = dynamic_api_key + # check if user passed in any of the OpenAI optional params + optional_params = get_optional_params( + functions=functions, + function_call=function_call, + temperature=temperature, + top_p=top_p, + n=n, + stream=stream, + stream_options=stream_options, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + user=user, + # params to identify the model + model=model, + custom_llm_provider=custom_llm_provider, + response_format=response_format, + seed=seed, + tools=tools, + tool_choice=tool_choice, + max_retries=max_retries, + logprobs=logprobs, + top_logprobs=top_logprobs, + extra_headers=extra_headers, + **non_default_params, + ) + + if litellm.add_function_to_prompt and optional_params.get( + "functions_unsupported_model", None + ): # if user opts to add it to prompt, when API doesn't support function calling + functions_unsupported_model = optional_params.pop( + "functions_unsupported_model" + ) + messages = function_call_prompt( + messages=messages, functions=functions_unsupported_model + ) + + # For logging - save the values of the litellm-specific params passed in + litellm_params = get_litellm_params( + acompletion=acompletion, + api_key=api_key, + force_timeout=force_timeout, + logger_fn=logger_fn, + verbose=verbose, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + litellm_call_id=kwargs.get("litellm_call_id", None), + model_alias_map=litellm.model_alias_map, + completion_call_id=id, + metadata=metadata, + model_info=model_info, + proxy_server_request=proxy_server_request, + preset_cache_key=preset_cache_key, + no_log=no_log, + ) + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params=litellm_params, + ) + if mock_response: + return mock_completion( + model, + messages, + stream=stream, + mock_response=mock_response, + logging=logging, + acompletion=acompletion, + ) + if custom_llm_provider == "azure": + # azure configs + api_type = get_secret("AZURE_API_TYPE") or "azure" + + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") + + api_version = ( + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) + + azure_ad_token = optional_params.get("extra_body", {}).pop( + "azure_ad_token", None + ) or get_secret("AZURE_AD_TOKEN") + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.AzureOpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## COMPLETION CALL + response = azure_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + api_version=api_version, + api_type=api_type, + azure_ad_token=azure_ad_token, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, # type: ignore + client=client, # pass AsyncAzureOpenAI, AzureOpenAI client + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={ + "headers": headers, + "api_version": api_version, + "api_base": api_base, + }, + ) + elif custom_llm_provider == "azure_text": + # azure configs + api_type = get_secret("AZURE_API_TYPE") or "azure" + + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") + + api_version = ( + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) + + azure_ad_token = optional_params.get("extra_body", {}).pop( + "azure_ad_token", None + ) or get_secret("AZURE_AD_TOKEN") + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.AzureOpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## COMPLETION CALL + response = azure_text_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + api_version=api_version, + api_type=api_type, + azure_ad_token=azure_ad_token, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, + client=client, # pass AsyncAzureOpenAI, AzureOpenAI client + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={ + "headers": headers, + "api_version": api_version, + "api_base": api_base, + }, + ) + elif ( + model in litellm.open_ai_chat_completion_models + or custom_llm_provider == "custom_openai" + or custom_llm_provider == "deepinfra" + or custom_llm_provider == "perplexity" + or custom_llm_provider == "groq" + or custom_llm_provider == "deepseek" + or custom_llm_provider == "anyscale" + or custom_llm_provider == "mistral" + or custom_llm_provider == "openai" + or custom_llm_provider == "together_ai" + or custom_llm_provider in litellm.openai_compatible_providers + or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo + ): # allow user to make an openai call with a custom base + # note: if a user sets a custom base - we should ensure this works + # allow for the setting of dynamic and stateful api-bases + api_base = ( + api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or get_secret("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + openai.organization = ( + organization + or litellm.organization + or get_secret("OPENAI_ORGANIZATION") + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.OpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## COMPLETION CALL + try: + response = openai_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + custom_prompt_dict=custom_prompt_dict, + client=client, # pass AsyncOpenAI, OpenAI client + organization=organization, + custom_llm_provider=custom_llm_provider, + ) + except Exception as e: + ## LOGGING - log the original exception returned + logging.post_call( + input=messages, + api_key=api_key, + original_response=str(e), + additional_args={"headers": headers}, + ) + raise e + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={"headers": headers}, + ) + elif ( + custom_llm_provider == "text-completion-openai" + or "ft:babbage-002" in model + or "ft:davinci-002" in model # support for finetuned completion models + ): + openai.api_type = "openai" + + api_base = ( + api_base + or litellm.api_base + or get_secret("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + + openai.api_version = None + # set API KEY + + api_key = ( + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.OpenAITextCompletionConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + if litellm.organization: + openai.organization = litellm.organization + + if ( + len(messages) > 0 + and "content" in messages[0] + and type(messages[0]["content"]) == list + ): + # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content'] + # https://platform.openai.com/docs/api-reference/completions/create + prompt = messages[0]["content"] + else: + prompt = " ".join([message["content"] for message in messages]) # type: ignore + + ## COMPLETION CALL + _response = openai_text_completions.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + client=client, # pass AsyncOpenAI, OpenAI client + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + ) + + if ( + optional_params.get("stream", False) == False + and acompletion == False + and text_completion == False + ): + # convert to chat completion response + _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object( + response_object=_response, model_response_object=model_response + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=_response, + additional_args={"headers": headers}, + ) + response = _response + elif ( + "replicate" in model + or custom_llm_provider == "replicate" + or model in litellm.replicate_models + ): + # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN") + replicate_key = None + replicate_key = ( + api_key + or litellm.replicate_key + or litellm.api_key + or get_secret("REPLICATE_API_KEY") + or get_secret("REPLICATE_API_TOKEN") + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("REPLICATE_API_BASE") + or "https://api.replicate.com/v1" + ) + + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + + model_response = replicate.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=replicate_key, + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, + ) + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=replicate_key, + original_response=model_response, + ) + + response = model_response + + elif custom_llm_provider == "anthropic": + api_key = ( + api_key + or litellm.anthropic_key + or litellm.api_key + or os.environ.get("ANTHROPIC_API_KEY") + ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + + if (model == "claude-2") or (model == "claude-instant-1"): + # call anthropic /completion, only use this route for claude-2, claude-instant-1 + api_base = ( + api_base + or litellm.api_base + or get_secret("ANTHROPIC_API_BASE") + or "https://api.anthropic.com/v1/complete" + ) + response = anthropic_text_completions.completion( + model=model, + messages=messages, + api_base=api_base, + acompletion=acompletion, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=api_key, + logging_obj=logging, + headers=headers, + ) + else: + # call /messages + # default route for all anthropic models + api_base = ( + api_base + or litellm.api_base + or get_secret("ANTHROPIC_API_BASE") + or "https://api.anthropic.com/v1/messages" + ) + response = anthropic_chat_completions.completion( + model=model, + messages=messages, + api_base=api_base, + acompletion=acompletion, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=api_key, + logging_obj=logging, + headers=headers, + ) + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + ) + response = response + elif custom_llm_provider == "nlp_cloud": + nlp_cloud_key = ( + api_key + or litellm.nlp_cloud_key + or get_secret("NLP_CLOUD_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("NLP_CLOUD_API_BASE") + or "https://api.nlpcloud.io/v1/gpu/" + ) + + response = nlp_cloud.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=nlp_cloud_key, + logging_obj=logging, + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="nlp_cloud", + logging_obj=logging, + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + ) + + response = response + elif custom_llm_provider == "aleph_alpha": + aleph_alpha_key = ( + api_key + or litellm.aleph_alpha_key + or get_secret("ALEPH_ALPHA_API_KEY") + or get_secret("ALEPHALPHA_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("ALEPH_ALPHA_API_BASE") + or "https://api.aleph-alpha.com/complete" + ) + + model_response = aleph_alpha.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + default_max_tokens_to_sample=litellm.max_tokens, + api_key=aleph_alpha_key, + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="aleph_alpha", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "cohere": + cohere_key = ( + api_key + or litellm.cohere_key + or get_secret("COHERE_API_KEY") + or get_secret("CO_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("COHERE_API_BASE") + or "https://api.cohere.ai/v1/generate" + ) + + model_response = cohere.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=cohere_key, + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="cohere", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "cohere_chat": + cohere_key = ( + api_key + or litellm.cohere_key + or get_secret("COHERE_API_KEY") + or get_secret("CO_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("COHERE_API_BASE") + or "https://api.cohere.ai/v1/chat" + ) + + model_response = cohere_chat.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=cohere_key, + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="cohere_chat", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "maritalk": + maritalk_key = ( + api_key + or litellm.maritalk_key + or get_secret("MARITALK_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("MARITALK_API_BASE") + or "https://chat.maritaca.ai/api/chat/inference" + ) + + model_response = maritalk.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=maritalk_key, + logging_obj=logging, + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="maritalk", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "huggingface": + custom_llm_provider = "huggingface" + huggingface_key = ( + api_key + or litellm.huggingface_key + or os.environ.get("HF_TOKEN") + or os.environ.get("HUGGINGFACE_API_KEY") + or litellm.api_key + ) + hf_headers = headers or litellm.headers + + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + model_response = huggingface.completion( + model=model, + messages=messages, + api_base=api_base, # type: ignore + headers=hf_headers, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=huggingface_key, + acompletion=acompletion, + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, + timeout=timeout, # type: ignore + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion is False + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="huggingface", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "oobabooga": + custom_llm_provider = "oobabooga" + model_response = oobabooga.completion( + model=model, + messages=messages, + model_response=model_response, + api_base=api_base, # type: ignore + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + api_key=None, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + ) + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="oobabooga", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "openrouter": + api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1" + + api_key = ( + api_key + or litellm.api_key + or litellm.openrouter_key + or get_secret("OPENROUTER_API_KEY") + or get_secret("OR_API_KEY") + ) + + openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai" + + openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" + + headers = ( + headers + or litellm.headers + or { + "HTTP-Referer": openrouter_site_url, + "X-Title": openrouter_app_name, + } + ) + + ## Load Config + config = openrouter.OpenrouterConfig.get_config() + for k, v in config.items(): + if k == "extra_body": + # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models + if "extra_body" in optional_params: + optional_params[k].update(v) + else: + optional_params[k] = v + elif k not in optional_params: + optional_params[k] = v + + data = {"model": model, "messages": messages, **optional_params} + + ## COMPLETION CALL + response = openai_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, # type: ignore + ) + ## LOGGING + logging.post_call( + input=messages, api_key=openai.api_key, original_response=response + ) + elif ( + custom_llm_provider == "together_ai" + or ("togethercomputer" in model) + or (model in litellm.together_ai_models) + ): + """ + Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility + """ + custom_llm_provider = "together_ai" + together_ai_key = ( + api_key + or litellm.togetherai_api_key + or get_secret("TOGETHER_AI_TOKEN") + or get_secret("TOGETHERAI_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("TOGETHERAI_API_BASE") + or "https://api.together.xyz/inference" + ) + + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + + model_response = together_ai.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=together_ai_key, + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, + ) + if ( + "stream_tokens" in optional_params + and optional_params["stream_tokens"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="together_ai", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "palm": + palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key + + # palm does not support streaming as yet :( + model_response = palm.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=palm_api_key, + logging_obj=logging, + ) + # fake palm streaming + if "stream" in optional_params and optional_params["stream"] == True: + # fake streaming for palm + resp_string = model_response["choices"][0]["message"]["content"] + response = CustomStreamWrapper( + resp_string, model, custom_llm_provider="palm", logging_obj=logging + ) + return response + response = model_response + elif custom_llm_provider == "gemini": + gemini_api_key = ( + api_key + or get_secret("GEMINI_API_KEY") + or get_secret("PALM_API_KEY") # older palm api key should also work + or litellm.api_key + ) + + # palm does not support streaming as yet :( + model_response = gemini.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=gemini_api_key, + logging_obj=logging, + acompletion=acompletion, + custom_prompt_dict=custom_prompt_dict, + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion == False + ): + response = CustomStreamWrapper( + iter(model_response), + model, + custom_llm_provider="gemini", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "vertex_ai": + vertex_ai_project = ( + optional_params.pop("vertex_project", None) + or optional_params.pop("vertex_ai_project", None) + or litellm.vertex_project + or get_secret("VERTEXAI_PROJECT") + ) + vertex_ai_location = ( + optional_params.pop("vertex_location", None) + or optional_params.pop("vertex_ai_location", None) + or litellm.vertex_location + or get_secret("VERTEXAI_LOCATION") + ) + vertex_credentials = ( + optional_params.pop("vertex_credentials", None) + or optional_params.pop("vertex_ai_credentials", None) + or get_secret("VERTEXAI_CREDENTIALS") + ) + new_params = deepcopy(optional_params) + if "claude-3" in model: + model_response = vertex_ai_anthropic.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + ) + else: + model_response = vertex_ai.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion == False + ): + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="vertex_ai", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "predibase": + tenant_id = ( + optional_params.pop("tenant_id", None) + or optional_params.pop("predibase_tenant_id", None) + or litellm.predibase_tenant_id + or get_secret("PREDIBASE_TENANT_ID") + ) + + api_base = ( + optional_params.pop("api_base", None) + or optional_params.pop("base_url", None) + or litellm.api_base + or get_secret("PREDIBASE_API_BASE") + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.predibase_key + or get_secret("PREDIBASE_API_KEY") + ) + +> model_response = predibase_chat_completions.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + acompletion=acompletion, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + api_key=api_key, + tenant_id=tenant_id, + ) + +../main.py:1813: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +self = +model = 'llama-3-8b-instruct' +messages = [{'content': 'What is the meaning of life?', 'role': 'user'}] +api_base = None, custom_prompt_dict = {} +model_response = ModelResponse(id='chatcmpl-755fcb98-22ba-46a2-9d6d-1a85b4363e98', choices=[Choices(finish_reason='stop', index=0, mess... role='assistant'))], created=1715301477, model=None, object='chat.completion', system_fingerprint=None, usage=Usage()) +print_verbose = +encoding = , api_key = 'pb_Qg9YbQo7UqqHdu0ozxN_aw' +logging_obj = +optional_params = {'details': True, 'max_new_tokens': 256, 'return_full_text': False} +tenant_id = 'c4768f95', acompletion = False +litellm_params = {'acompletion': False, 'api_base': 'https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate_stream', 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'completion_call_id': None, ...} +logger_fn = None +headers = {'Authorization': 'Bearer pb_Qg9YbQo7UqqHdu0ozxN_aw', 'content-type': 'application/json'} + + def completion( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key: str, + logging_obj, + optional_params: dict, + tenant_id: str, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers: dict = {}, + ) -> Union[ModelResponse, CustomStreamWrapper]: + headers = self.validate_environment(api_key, headers) + completion_url = "" + input_text = "" + base_url = "https://serving.app.predibase.com" + if "https" in model: + completion_url = model + elif api_base: + base_url = api_base + elif "PREDIBASE_API_BASE" in os.environ: + base_url = os.getenv("PREDIBASE_API_BASE", "") + + completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}" + + if optional_params.get("stream", False) == True: + completion_url += "/generate_stream" + else: + completion_url += "/generate" + + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) + else: + prompt = prompt_factory(model=model, messages=messages) + + ## Load Config + config = litellm.PredibaseConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + stream = optional_params.pop("stream", False) + + data = { + "inputs": prompt, + "parameters": optional_params, + } + input_text = prompt + ## LOGGING + logging_obj.pre_call( + input=input_text, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "headers": headers, + "api_base": completion_url, + "acompletion": acompletion, + }, + ) + ## COMPLETION CALL + if acompletion is True: + ### ASYNC STREAMING + if stream == True: + return self.async_streaming( + model=model, + messages=messages, + data=data, + api_base=completion_url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + ) # type: ignore + else: + ### ASYNC COMPLETION + return self.async_completion( + model=model, + messages=messages, + data=data, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + stream=False, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + ) # type: ignore + + ### SYNC STREAMING + if stream == True: + response = requests.post( + completion_url, + headers=headers, + data=json.dumps(data), +> stream=optional_params["stream"], + ) +E KeyError: 'stream' + +../llms/predibase.py:412: KeyError + +During handling of the above exception, another exception occurred: + +sync_mode = True + + @pytest.mark.parametrize("sync_mode", [True, False]) + @pytest.mark.asyncio + async def test_completion_predibase_streaming(sync_mode): + try: + litellm.set_verbose = True + + if sync_mode: +> response = completion( + model="predibase/llama-3-8b-instruct", + tenant_id="c4768f95", + api_base="https://serving.app.predibase.com", + api_key=os.getenv("PREDIBASE_API_KEY"), + messages=[{"role": "user", "content": "What is the meaning of life?"}], + stream=True, + ) + +test_streaming.py:317: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +args = () +kwargs = {'api_base': 'https://serving.app.predibase.com', 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , ...} +result = None, start_time = datetime.datetime(2024, 5, 9, 17, 37, 57, 884661) +logging_obj = +call_type = 'completion', model = 'predibase/llama-3-8b-instruct' +k = 'litellm_logging_obj' + + @wraps(original_function) + def wrapper(*args, **kwargs): + # DO NOT MOVE THIS. It always needs to run first + # Check if this is an async function. If so only execute the async function + if ( + kwargs.get("acompletion", False) == True + or kwargs.get("aembedding", False) == True + or kwargs.get("aimg_generation", False) == True + or kwargs.get("amoderation", False) == True + or kwargs.get("atext_completion", False) == True + or kwargs.get("atranscription", False) == True + ): + # [OPTIONAL] CHECK MAX RETRIES / REQUEST + if litellm.num_retries_per_request is not None: + # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] + previous_models = kwargs.get("metadata", {}).get( + "previous_models", None + ) + if previous_models is not None: + if litellm.num_retries_per_request <= len(previous_models): + raise Exception(f"Max retries per request hit!") + + # MODEL CALL + result = original_function(*args, **kwargs) + if "stream" in kwargs and kwargs["stream"] == True: + if ( + "complete_response" in kwargs + and kwargs["complete_response"] == True + ): + chunks = [] + for idx, chunk in enumerate(result): + chunks.append(chunk) + return litellm.stream_chunk_builder( + chunks, messages=kwargs.get("messages", None) + ) + else: + return result + return result + + # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print + print_args_passed_to_litellm(original_function, args, kwargs) + start_time = datetime.datetime.now() + result = None + logging_obj = kwargs.get("litellm_logging_obj", None) + + # only set litellm_call_id if its not in kwargs + call_type = original_function.__name__ + if "litellm_call_id" not in kwargs: + kwargs["litellm_call_id"] = str(uuid.uuid4()) + try: + model = args[0] if len(args) > 0 else kwargs["model"] + except: + model = None + if ( + call_type != CallTypes.image_generation.value + and call_type != CallTypes.text_completion.value + ): + raise ValueError("model param not passed in.") + + try: + if logging_obj is None: + logging_obj, kwargs = function_setup( + original_function.__name__, rules_obj, start_time, *args, **kwargs + ) + kwargs["litellm_logging_obj"] = logging_obj + + # CHECK FOR 'os.environ/' in kwargs + for k, v in kwargs.items(): + if v is not None and isinstance(v, str) and v.startswith("os.environ/"): + kwargs[k] = litellm.get_secret(v) + # [OPTIONAL] CHECK BUDGET + if litellm.max_budget: + if litellm._current_cost > litellm.max_budget: + raise BudgetExceededError( + current_cost=litellm._current_cost, + max_budget=litellm.max_budget, + ) + + # [OPTIONAL] CHECK MAX RETRIES / REQUEST + if litellm.num_retries_per_request is not None: + # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] + previous_models = kwargs.get("metadata", {}).get( + "previous_models", None + ) + if previous_models is not None: + if litellm.num_retries_per_request <= len(previous_models): + raise Exception(f"Max retries per request hit!") + + # [OPTIONAL] CHECK CACHE + print_verbose( + f"SYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache')['no-cache']: {kwargs.get('cache', {}).get('no-cache', False)}" + ) + # if caching is false or cache["no-cache"]==True, don't run this + if ( + ( + ( + ( + kwargs.get("caching", None) is None + and litellm.cache is not None + ) + or kwargs.get("caching", False) == True + ) + and kwargs.get("cache", {}).get("no-cache", False) != True + ) + and kwargs.get("aembedding", False) != True + and kwargs.get("atext_completion", False) != True + and kwargs.get("acompletion", False) != True + and kwargs.get("aimg_generation", False) != True + and kwargs.get("atranscription", False) != True + ): # allow users to control returning cached responses from the completion function + # checking cache + print_verbose(f"INSIDE CHECKING CACHE") + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): + print_verbose(f"Checking Cache") + preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) + kwargs["preset_cache_key"] = ( + preset_cache_key # for streaming calls, we need to pass the preset_cache_key + ) + cached_result = litellm.cache.get_cache(*args, **kwargs) + if cached_result != None: + if "detail" in cached_result: + # implies an error occurred + pass + else: + call_type = original_function.__name__ + print_verbose( + f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" + ) + if call_type == CallTypes.completion.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=ModelResponse(), + stream=kwargs.get("stream", False), + ) + if kwargs.get("stream", False) == True: + cached_result = CustomStreamWrapper( + completion_stream=cached_result, + model=model, + custom_llm_provider="cached_response", + logging_obj=logging_obj, + ) + elif call_type == CallTypes.embedding.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + response_type="embedding", + ) + + # LOG SUCCESS + cache_hit = True + end_time = datetime.datetime.now() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get( + "custom_llm_provider", None + ), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + print_verbose( + f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + ) + logging_obj.update_environment_variables( + model=model, + user=kwargs.get("user", None), + optional_params={}, + litellm_params={ + "logger_fn": kwargs.get("logger_fn", None), + "acompletion": False, + "metadata": kwargs.get("metadata", {}), + "model_info": kwargs.get("model_info", {}), + "proxy_server_request": kwargs.get( + "proxy_server_request", None + ), + "preset_cache_key": kwargs.get( + "preset_cache_key", None + ), + "stream_response": kwargs.get( + "stream_response", {} + ), + }, + input=kwargs.get("messages", ""), + api_key=kwargs.get("api_key", None), + original_response=str(cached_result), + additional_args=None, + stream=kwargs.get("stream", False), + ) + threading.Thread( + target=logging_obj.success_handler, + args=(cached_result, start_time, end_time, cache_hit), + ).start() + return cached_result + + # CHECK MAX TOKENS + if ( + kwargs.get("max_tokens", None) is not None + and model is not None + and litellm.modify_params + == True # user is okay with params being modified + and ( + call_type == CallTypes.acompletion.value + or call_type == CallTypes.completion.value + ) + ): + try: + base_model = model + if kwargs.get("hf_model_name", None) is not None: + base_model = f"huggingface/{kwargs.get('hf_model_name')}" + max_output_tokens = ( + get_max_tokens(model=base_model) or 4096 + ) # assume min context window is 4k tokens + user_max_tokens = kwargs.get("max_tokens") + ## Scenario 1: User limit + prompt > model limit + messages = None + if len(args) > 1: + messages = args[1] + elif kwargs.get("messages", None): + messages = kwargs["messages"] + input_tokens = token_counter(model=base_model, messages=messages) + input_tokens += max( + 0.1 * input_tokens, 10 + ) # give at least a 10 token buffer. token counting can be imprecise. + if input_tokens > max_output_tokens: + pass # allow call to fail normally + elif user_max_tokens + input_tokens > max_output_tokens: + user_max_tokens = max_output_tokens - input_tokens + print_verbose(f"user_max_tokens: {user_max_tokens}") + kwargs["max_tokens"] = int( + round(user_max_tokens) + ) # make sure max tokens is always an int + except Exception as e: + print_verbose(f"Error while checking max token limit: {str(e)}") + # MODEL CALL + result = original_function(*args, **kwargs) + end_time = datetime.datetime.now() + if "stream" in kwargs and kwargs["stream"] == True: + if ( + "complete_response" in kwargs + and kwargs["complete_response"] == True + ): + chunks = [] + for idx, chunk in enumerate(result): + chunks.append(chunk) + return litellm.stream_chunk_builder( + chunks, messages=kwargs.get("messages", None) + ) + else: + return result + elif "acompletion" in kwargs and kwargs["acompletion"] == True: + return result + elif "aembedding" in kwargs and kwargs["aembedding"] == True: + return result + elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True: + return result + elif "atranscription" in kwargs and kwargs["atranscription"] == True: + return result + + ### POST-CALL RULES ### + post_call_processing(original_response=result, model=model or None) + + # [OPTIONAL] ADD TO CACHE + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ) and (kwargs.get("cache", {}).get("no-store", False) != True): + litellm.cache.add_cache(result, *args, **kwargs) + + # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated + verbose_logger.info(f"Wrapper: Completed Call, calling success_handler") + threading.Thread( + target=logging_obj.success_handler, args=(result, start_time, end_time) + ).start() + # RETURN RESULT + if hasattr(result, "_hidden_params"): + result._hidden_params["model_id"] = kwargs.get("model_info", {}).get( + "id", None + ) + result._hidden_params["api_base"] = get_api_base( + model=model, + optional_params=getattr(logging_obj, "optional_params", {}), + ) + result._response_ms = ( + end_time - start_time + ).total_seconds() * 1000 # return response latency in ms like openai + return result + except Exception as e: + call_type = original_function.__name__ + if call_type == CallTypes.completion.value: + num_retries = ( + kwargs.get("num_retries", None) or litellm.num_retries or None + ) + litellm.num_retries = ( + None # set retries to None to prevent infinite loops + ) + context_window_fallback_dict = kwargs.get( + "context_window_fallback_dict", {} + ) + + _is_litellm_router_call = "model_group" in kwargs.get( + "metadata", {} + ) # check if call from litellm.router/proxy + if ( + num_retries and not _is_litellm_router_call + ): # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying + if ( + isinstance(e, openai.APIError) + or isinstance(e, openai.Timeout) + or isinstance(e, openai.APIConnectionError) + ): + kwargs["num_retries"] = num_retries + return litellm.completion_with_retries(*args, **kwargs) + elif ( + isinstance(e, litellm.exceptions.ContextWindowExceededError) + and context_window_fallback_dict + and model in context_window_fallback_dict + ): + if len(args) > 0: + args[0] = context_window_fallback_dict[model] + else: + kwargs["model"] = context_window_fallback_dict[model] + return original_function(*args, **kwargs) + traceback_exception = traceback.format_exc() + end_time = datetime.datetime.now() + # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated + if logging_obj: + logging_obj.failure_handler( + e, traceback_exception, start_time, end_time + ) # DO NOT MAKE THREADED - router retry fallback relies on this! + my_thread = threading.Thread( + target=handle_failure, + args=(e, traceback_exception, start_time, end_time, args, kwargs), + ) # don't interrupt execution of main thread + my_thread.start() + if hasattr(e, "message"): + if ( + liteDebuggerClient and liteDebuggerClient.dashboard_url != None + ): # make it easy to get to the debugger logs if you've initialized it + e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}" +> raise e + +../utils.py:3229: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +args = () +kwargs = {'api_base': 'https://serving.app.predibase.com', 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , ...} +result = None, start_time = datetime.datetime(2024, 5, 9, 17, 37, 57, 884661) +logging_obj = +call_type = 'completion', model = 'predibase/llama-3-8b-instruct' +k = 'litellm_logging_obj' + + @wraps(original_function) + def wrapper(*args, **kwargs): + # DO NOT MOVE THIS. It always needs to run first + # Check if this is an async function. If so only execute the async function + if ( + kwargs.get("acompletion", False) == True + or kwargs.get("aembedding", False) == True + or kwargs.get("aimg_generation", False) == True + or kwargs.get("amoderation", False) == True + or kwargs.get("atext_completion", False) == True + or kwargs.get("atranscription", False) == True + ): + # [OPTIONAL] CHECK MAX RETRIES / REQUEST + if litellm.num_retries_per_request is not None: + # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] + previous_models = kwargs.get("metadata", {}).get( + "previous_models", None + ) + if previous_models is not None: + if litellm.num_retries_per_request <= len(previous_models): + raise Exception(f"Max retries per request hit!") + + # MODEL CALL + result = original_function(*args, **kwargs) + if "stream" in kwargs and kwargs["stream"] == True: + if ( + "complete_response" in kwargs + and kwargs["complete_response"] == True + ): + chunks = [] + for idx, chunk in enumerate(result): + chunks.append(chunk) + return litellm.stream_chunk_builder( + chunks, messages=kwargs.get("messages", None) + ) + else: + return result + return result + + # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print + print_args_passed_to_litellm(original_function, args, kwargs) + start_time = datetime.datetime.now() + result = None + logging_obj = kwargs.get("litellm_logging_obj", None) + + # only set litellm_call_id if its not in kwargs + call_type = original_function.__name__ + if "litellm_call_id" not in kwargs: + kwargs["litellm_call_id"] = str(uuid.uuid4()) + try: + model = args[0] if len(args) > 0 else kwargs["model"] + except: + model = None + if ( + call_type != CallTypes.image_generation.value + and call_type != CallTypes.text_completion.value + ): + raise ValueError("model param not passed in.") + + try: + if logging_obj is None: + logging_obj, kwargs = function_setup( + original_function.__name__, rules_obj, start_time, *args, **kwargs + ) + kwargs["litellm_logging_obj"] = logging_obj + + # CHECK FOR 'os.environ/' in kwargs + for k, v in kwargs.items(): + if v is not None and isinstance(v, str) and v.startswith("os.environ/"): + kwargs[k] = litellm.get_secret(v) + # [OPTIONAL] CHECK BUDGET + if litellm.max_budget: + if litellm._current_cost > litellm.max_budget: + raise BudgetExceededError( + current_cost=litellm._current_cost, + max_budget=litellm.max_budget, + ) + + # [OPTIONAL] CHECK MAX RETRIES / REQUEST + if litellm.num_retries_per_request is not None: + # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] + previous_models = kwargs.get("metadata", {}).get( + "previous_models", None + ) + if previous_models is not None: + if litellm.num_retries_per_request <= len(previous_models): + raise Exception(f"Max retries per request hit!") + + # [OPTIONAL] CHECK CACHE + print_verbose( + f"SYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache')['no-cache']: {kwargs.get('cache', {}).get('no-cache', False)}" + ) + # if caching is false or cache["no-cache"]==True, don't run this + if ( + ( + ( + ( + kwargs.get("caching", None) is None + and litellm.cache is not None + ) + or kwargs.get("caching", False) == True + ) + and kwargs.get("cache", {}).get("no-cache", False) != True + ) + and kwargs.get("aembedding", False) != True + and kwargs.get("atext_completion", False) != True + and kwargs.get("acompletion", False) != True + and kwargs.get("aimg_generation", False) != True + and kwargs.get("atranscription", False) != True + ): # allow users to control returning cached responses from the completion function + # checking cache + print_verbose(f"INSIDE CHECKING CACHE") + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): + print_verbose(f"Checking Cache") + preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) + kwargs["preset_cache_key"] = ( + preset_cache_key # for streaming calls, we need to pass the preset_cache_key + ) + cached_result = litellm.cache.get_cache(*args, **kwargs) + if cached_result != None: + if "detail" in cached_result: + # implies an error occurred + pass + else: + call_type = original_function.__name__ + print_verbose( + f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" + ) + if call_type == CallTypes.completion.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=ModelResponse(), + stream=kwargs.get("stream", False), + ) + if kwargs.get("stream", False) == True: + cached_result = CustomStreamWrapper( + completion_stream=cached_result, + model=model, + custom_llm_provider="cached_response", + logging_obj=logging_obj, + ) + elif call_type == CallTypes.embedding.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + response_type="embedding", + ) + + # LOG SUCCESS + cache_hit = True + end_time = datetime.datetime.now() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get( + "custom_llm_provider", None + ), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + print_verbose( + f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + ) + logging_obj.update_environment_variables( + model=model, + user=kwargs.get("user", None), + optional_params={}, + litellm_params={ + "logger_fn": kwargs.get("logger_fn", None), + "acompletion": False, + "metadata": kwargs.get("metadata", {}), + "model_info": kwargs.get("model_info", {}), + "proxy_server_request": kwargs.get( + "proxy_server_request", None + ), + "preset_cache_key": kwargs.get( + "preset_cache_key", None + ), + "stream_response": kwargs.get( + "stream_response", {} + ), + }, + input=kwargs.get("messages", ""), + api_key=kwargs.get("api_key", None), + original_response=str(cached_result), + additional_args=None, + stream=kwargs.get("stream", False), + ) + threading.Thread( + target=logging_obj.success_handler, + args=(cached_result, start_time, end_time, cache_hit), + ).start() + return cached_result + + # CHECK MAX TOKENS + if ( + kwargs.get("max_tokens", None) is not None + and model is not None + and litellm.modify_params + == True # user is okay with params being modified + and ( + call_type == CallTypes.acompletion.value + or call_type == CallTypes.completion.value + ) + ): + try: + base_model = model + if kwargs.get("hf_model_name", None) is not None: + base_model = f"huggingface/{kwargs.get('hf_model_name')}" + max_output_tokens = ( + get_max_tokens(model=base_model) or 4096 + ) # assume min context window is 4k tokens + user_max_tokens = kwargs.get("max_tokens") + ## Scenario 1: User limit + prompt > model limit + messages = None + if len(args) > 1: + messages = args[1] + elif kwargs.get("messages", None): + messages = kwargs["messages"] + input_tokens = token_counter(model=base_model, messages=messages) + input_tokens += max( + 0.1 * input_tokens, 10 + ) # give at least a 10 token buffer. token counting can be imprecise. + if input_tokens > max_output_tokens: + pass # allow call to fail normally + elif user_max_tokens + input_tokens > max_output_tokens: + user_max_tokens = max_output_tokens - input_tokens + print_verbose(f"user_max_tokens: {user_max_tokens}") + kwargs["max_tokens"] = int( + round(user_max_tokens) + ) # make sure max tokens is always an int + except Exception as e: + print_verbose(f"Error while checking max token limit: {str(e)}") + # MODEL CALL +> result = original_function(*args, **kwargs) + +../utils.py:3123: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +model = 'llama-3-8b-instruct' +messages = [{'content': 'What is the meaning of life?', 'role': 'user'}] +timeout = 600.0, temperature = None, top_p = None, n = None, stream = True +stream_options = None, stop = None, max_tokens = None, presence_penalty = None +frequency_penalty = None, logit_bias = None, user = None, response_format = None +seed = None, tools = None, tool_choice = None, logprobs = None +top_logprobs = None, deployment_id = None, extra_headers = None +functions = None, function_call = None, base_url = None, api_version = None +api_key = 'pb_Qg9YbQo7UqqHdu0ozxN_aw', model_list = None +kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , 'tenant_id': 'c4768f95'} +args = {'acompletion': False, 'api_base': None, 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'api_version': None, ...} +api_base = None, mock_response = None, force_timeout = 600, logger_fn = None +verbose = False, custom_llm_provider = 'predibase' + + @client + def completion( + model: str, + # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create + messages: List = [], + timeout: Optional[Union[float, str, httpx.Timeout]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream: Optional[bool] = None, + stream_options: Optional[dict] = None, + stop=None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[dict] = None, + user: Optional[str] = None, + # openai v1.0+ new params + response_format: Optional[dict] = None, + seed: Optional[int] = None, + tools: Optional[List] = None, + tool_choice: Optional[str] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + deployment_id=None, + extra_headers: Optional[dict] = None, + # soon to be deprecated params by OpenAI + functions: Optional[List] = None, + function_call: Optional[str] = None, + # set api_base, api_version, api_key + base_url: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + # Optional liteLLM function params + **kwargs, + ) -> Union[ModelResponse, CustomStreamWrapper]: + """ + Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) + Parameters: + model (str): The name of the language model to use for text completion. see all supported LLMs: https://docs.litellm.ai/docs/providers/ + messages (List): A list of message objects representing the conversation context (default is an empty list). + + OPTIONAL PARAMS + functions (List, optional): A list of functions to apply to the conversation messages (default is an empty list). + function_call (str, optional): The name of the function to call within the conversation (default is an empty string). + temperature (float, optional): The temperature parameter for controlling the randomness of the output (default is 1.0). + top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). + n (int, optional): The number of completions to generate (default is 1). + stream (bool, optional): If True, return a streaming response (default is False). + stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true. + stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. + max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). + presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. + frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. + logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. + user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. + logprobs (bool, optional): Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message + top_logprobs (int, optional): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. + metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. + api_base (str, optional): Base URL for the API (default is None). + api_version (str, optional): API version (default is None). + api_key (str, optional): API key (default is None). + model_list (list, optional): List of api base, version, keys + extra_headers (dict, optional): Additional headers to include in the request. + + LITELLM Specific Params + mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None). + custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock" + max_retries (int, optional): The number of retries to attempt (default is 0). + Returns: + ModelResponse: A response object containing the generated completion and associated metadata. + + Note: + - This function is used to perform completions() using the specified language model. + - It supports various optional parameters for customizing the completion behavior. + - If 'mock_response' is provided, a mock completion response is returned for testing or debugging. + """ + ######### unpacking kwargs ##################### + args = locals() + api_base = kwargs.get("api_base", None) + mock_response = kwargs.get("mock_response", None) + force_timeout = kwargs.get("force_timeout", 600) ## deprecated + logger_fn = kwargs.get("logger_fn", None) + verbose = kwargs.get("verbose", False) + custom_llm_provider = kwargs.get("custom_llm_provider", None) + litellm_logging_obj = kwargs.get("litellm_logging_obj", None) + id = kwargs.get("id", None) + metadata = kwargs.get("metadata", None) + model_info = kwargs.get("model_info", None) + proxy_server_request = kwargs.get("proxy_server_request", None) + fallbacks = kwargs.get("fallbacks", None) + headers = kwargs.get("headers", None) + num_retries = kwargs.get("num_retries", None) ## deprecated + max_retries = kwargs.get("max_retries", None) + context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None) + organization = kwargs.get("organization", None) + ### CUSTOM MODEL COST ### + input_cost_per_token = kwargs.get("input_cost_per_token", None) + output_cost_per_token = kwargs.get("output_cost_per_token", None) + input_cost_per_second = kwargs.get("input_cost_per_second", None) + output_cost_per_second = kwargs.get("output_cost_per_second", None) + ### CUSTOM PROMPT TEMPLATE ### + initial_prompt_value = kwargs.get("initial_prompt_value", None) + roles = kwargs.get("roles", None) + final_prompt_value = kwargs.get("final_prompt_value", None) + bos_token = kwargs.get("bos_token", None) + eos_token = kwargs.get("eos_token", None) + preset_cache_key = kwargs.get("preset_cache_key", None) + hf_model_name = kwargs.get("hf_model_name", None) + supports_system_message = kwargs.get("supports_system_message", None) + ### TEXT COMPLETION CALLS ### + text_completion = kwargs.get("text_completion", False) + atext_completion = kwargs.get("atext_completion", False) + ### ASYNC CALLS ### + acompletion = kwargs.get("acompletion", False) + client = kwargs.get("client", None) + ### Admin Controls ### + no_log = kwargs.get("no-log", False) + ######## end of unpacking kwargs ########### + openai_params = [ + "functions", + "function_call", + "temperature", + "temperature", + "top_p", + "n", + "stream", + "stream_options", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + "logprobs", + "top_logprobs", + "extra_headers", + ] + litellm_params = [ + "metadata", + "acompletion", + "atext_completion", + "text_completion", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "retry_policy", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "max_parallel_requests", + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_second", + "output_cost_per_second", + "hf_model_name", + "model_info", + "proxy_server_request", + "preset_cache_key", + "caching_groups", + "ttl", + "cache", + "no-log", + "base_model", + "stream_timeout", + "supports_system_message", + "region_name", + "allowed_model_region", + ] + default_params = openai_params + litellm_params + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider + + ### TIMEOUT LOGIC ### + timeout = timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + + try: + if base_url is not None: + api_base = base_url + if max_retries is not None: # openai allows openai.OpenAI(max_retries=3) + num_retries = max_retries + logging = litellm_logging_obj + fallbacks = fallbacks or litellm.model_fallbacks + if fallbacks is not None: + return completion_with_fallbacks(**args) + if model_list is not None: + deployments = [ + m["litellm_params"] for m in model_list if m["model_name"] == model + ] + return batch_completion_models(deployments=deployments, **args) + if litellm.model_alias_map and model in litellm.model_alias_map: + model = litellm.model_alias_map[ + model + ] # update the model to the actual value if an alias has been passed in + model_response = ModelResponse() + setattr(model_response, "usage", litellm.Usage()) + if ( + kwargs.get("azure", False) == True + ): # don't remove flag check, to remain backwards compatible for repos like Codium + custom_llm_provider = "azure" + if deployment_id != None: # azure llms + model = deployment_id + custom_llm_provider = "azure" + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + ) + if model_response is not None and hasattr(model_response, "_hidden_params"): + model_response._hidden_params["custom_llm_provider"] = custom_llm_provider + model_response._hidden_params["region_name"] = kwargs.get( + "aws_region_name", None + ) # support region-based pricing for bedrock + + ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### + if input_cost_per_token is not None and output_cost_per_token is not None: + print_verbose(f"Registering model={model} in model cost map") + litellm.register_model( + { + f"{custom_llm_provider}/{model}": { + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "litellm_provider": custom_llm_provider, + }, + model: { + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "litellm_provider": custom_llm_provider, + }, + } + ) + elif ( + input_cost_per_second is not None + ): # time based pricing just needs cost in place + output_cost_per_second = output_cost_per_second + litellm.register_model( + { + f"{custom_llm_provider}/{model}": { + "input_cost_per_second": input_cost_per_second, + "output_cost_per_second": output_cost_per_second, + "litellm_provider": custom_llm_provider, + }, + model: { + "input_cost_per_second": input_cost_per_second, + "output_cost_per_second": output_cost_per_second, + "litellm_provider": custom_llm_provider, + }, + } + ) + ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### + custom_prompt_dict = {} # type: ignore + if ( + initial_prompt_value + or roles + or final_prompt_value + or bos_token + or eos_token + ): + custom_prompt_dict = {model: {}} + if initial_prompt_value: + custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value + if roles: + custom_prompt_dict[model]["roles"] = roles + if final_prompt_value: + custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value + if bos_token: + custom_prompt_dict[model]["bos_token"] = bos_token + if eos_token: + custom_prompt_dict[model]["eos_token"] = eos_token + + if ( + supports_system_message is not None + and isinstance(supports_system_message, bool) + and supports_system_message == False + ): + messages = map_system_message_pt(messages=messages) + model_api_key = get_api_key( + llm_provider=custom_llm_provider, dynamic_api_key=api_key + ) # get the api key from the environment if required for the model + + if dynamic_api_key is not None: + api_key = dynamic_api_key + # check if user passed in any of the OpenAI optional params + optional_params = get_optional_params( + functions=functions, + function_call=function_call, + temperature=temperature, + top_p=top_p, + n=n, + stream=stream, + stream_options=stream_options, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + user=user, + # params to identify the model + model=model, + custom_llm_provider=custom_llm_provider, + response_format=response_format, + seed=seed, + tools=tools, + tool_choice=tool_choice, + max_retries=max_retries, + logprobs=logprobs, + top_logprobs=top_logprobs, + extra_headers=extra_headers, + **non_default_params, + ) + + if litellm.add_function_to_prompt and optional_params.get( + "functions_unsupported_model", None + ): # if user opts to add it to prompt, when API doesn't support function calling + functions_unsupported_model = optional_params.pop( + "functions_unsupported_model" + ) + messages = function_call_prompt( + messages=messages, functions=functions_unsupported_model + ) + + # For logging - save the values of the litellm-specific params passed in + litellm_params = get_litellm_params( + acompletion=acompletion, + api_key=api_key, + force_timeout=force_timeout, + logger_fn=logger_fn, + verbose=verbose, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + litellm_call_id=kwargs.get("litellm_call_id", None), + model_alias_map=litellm.model_alias_map, + completion_call_id=id, + metadata=metadata, + model_info=model_info, + proxy_server_request=proxy_server_request, + preset_cache_key=preset_cache_key, + no_log=no_log, + ) + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params=litellm_params, + ) + if mock_response: + return mock_completion( + model, + messages, + stream=stream, + mock_response=mock_response, + logging=logging, + acompletion=acompletion, + ) + if custom_llm_provider == "azure": + # azure configs + api_type = get_secret("AZURE_API_TYPE") or "azure" + + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") + + api_version = ( + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) + + azure_ad_token = optional_params.get("extra_body", {}).pop( + "azure_ad_token", None + ) or get_secret("AZURE_AD_TOKEN") + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.AzureOpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## COMPLETION CALL + response = azure_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + api_version=api_version, + api_type=api_type, + azure_ad_token=azure_ad_token, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, # type: ignore + client=client, # pass AsyncAzureOpenAI, AzureOpenAI client + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={ + "headers": headers, + "api_version": api_version, + "api_base": api_base, + }, + ) + elif custom_llm_provider == "azure_text": + # azure configs + api_type = get_secret("AZURE_API_TYPE") or "azure" + + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") + + api_version = ( + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) + + azure_ad_token = optional_params.get("extra_body", {}).pop( + "azure_ad_token", None + ) or get_secret("AZURE_AD_TOKEN") + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.AzureOpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## COMPLETION CALL + response = azure_text_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + api_version=api_version, + api_type=api_type, + azure_ad_token=azure_ad_token, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, + client=client, # pass AsyncAzureOpenAI, AzureOpenAI client + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={ + "headers": headers, + "api_version": api_version, + "api_base": api_base, + }, + ) + elif ( + model in litellm.open_ai_chat_completion_models + or custom_llm_provider == "custom_openai" + or custom_llm_provider == "deepinfra" + or custom_llm_provider == "perplexity" + or custom_llm_provider == "groq" + or custom_llm_provider == "deepseek" + or custom_llm_provider == "anyscale" + or custom_llm_provider == "mistral" + or custom_llm_provider == "openai" + or custom_llm_provider == "together_ai" + or custom_llm_provider in litellm.openai_compatible_providers + or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo + ): # allow user to make an openai call with a custom base + # note: if a user sets a custom base - we should ensure this works + # allow for the setting of dynamic and stateful api-bases + api_base = ( + api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or get_secret("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + openai.organization = ( + organization + or litellm.organization + or get_secret("OPENAI_ORGANIZATION") + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.OpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## COMPLETION CALL + try: + response = openai_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + custom_prompt_dict=custom_prompt_dict, + client=client, # pass AsyncOpenAI, OpenAI client + organization=organization, + custom_llm_provider=custom_llm_provider, + ) + except Exception as e: + ## LOGGING - log the original exception returned + logging.post_call( + input=messages, + api_key=api_key, + original_response=str(e), + additional_args={"headers": headers}, + ) + raise e + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={"headers": headers}, + ) + elif ( + custom_llm_provider == "text-completion-openai" + or "ft:babbage-002" in model + or "ft:davinci-002" in model # support for finetuned completion models + ): + openai.api_type = "openai" + + api_base = ( + api_base + or litellm.api_base + or get_secret("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + + openai.api_version = None + # set API KEY + + api_key = ( + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.OpenAITextCompletionConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + if litellm.organization: + openai.organization = litellm.organization + + if ( + len(messages) > 0 + and "content" in messages[0] + and type(messages[0]["content"]) == list + ): + # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content'] + # https://platform.openai.com/docs/api-reference/completions/create + prompt = messages[0]["content"] + else: + prompt = " ".join([message["content"] for message in messages]) # type: ignore + + ## COMPLETION CALL + _response = openai_text_completions.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + client=client, # pass AsyncOpenAI, OpenAI client + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + ) + + if ( + optional_params.get("stream", False) == False + and acompletion == False + and text_completion == False + ): + # convert to chat completion response + _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object( + response_object=_response, model_response_object=model_response + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=_response, + additional_args={"headers": headers}, + ) + response = _response + elif ( + "replicate" in model + or custom_llm_provider == "replicate" + or model in litellm.replicate_models + ): + # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN") + replicate_key = None + replicate_key = ( + api_key + or litellm.replicate_key + or litellm.api_key + or get_secret("REPLICATE_API_KEY") + or get_secret("REPLICATE_API_TOKEN") + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("REPLICATE_API_BASE") + or "https://api.replicate.com/v1" + ) + + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + + model_response = replicate.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=replicate_key, + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, + ) + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=replicate_key, + original_response=model_response, + ) + + response = model_response + + elif custom_llm_provider == "anthropic": + api_key = ( + api_key + or litellm.anthropic_key + or litellm.api_key + or os.environ.get("ANTHROPIC_API_KEY") + ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + + if (model == "claude-2") or (model == "claude-instant-1"): + # call anthropic /completion, only use this route for claude-2, claude-instant-1 + api_base = ( + api_base + or litellm.api_base + or get_secret("ANTHROPIC_API_BASE") + or "https://api.anthropic.com/v1/complete" + ) + response = anthropic_text_completions.completion( + model=model, + messages=messages, + api_base=api_base, + acompletion=acompletion, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=api_key, + logging_obj=logging, + headers=headers, + ) + else: + # call /messages + # default route for all anthropic models + api_base = ( + api_base + or litellm.api_base + or get_secret("ANTHROPIC_API_BASE") + or "https://api.anthropic.com/v1/messages" + ) + response = anthropic_chat_completions.completion( + model=model, + messages=messages, + api_base=api_base, + acompletion=acompletion, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=api_key, + logging_obj=logging, + headers=headers, + ) + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + ) + response = response + elif custom_llm_provider == "nlp_cloud": + nlp_cloud_key = ( + api_key + or litellm.nlp_cloud_key + or get_secret("NLP_CLOUD_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("NLP_CLOUD_API_BASE") + or "https://api.nlpcloud.io/v1/gpu/" + ) + + response = nlp_cloud.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=nlp_cloud_key, + logging_obj=logging, + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="nlp_cloud", + logging_obj=logging, + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + ) + + response = response + elif custom_llm_provider == "aleph_alpha": + aleph_alpha_key = ( + api_key + or litellm.aleph_alpha_key + or get_secret("ALEPH_ALPHA_API_KEY") + or get_secret("ALEPHALPHA_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("ALEPH_ALPHA_API_BASE") + or "https://api.aleph-alpha.com/complete" + ) + + model_response = aleph_alpha.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + default_max_tokens_to_sample=litellm.max_tokens, + api_key=aleph_alpha_key, + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="aleph_alpha", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "cohere": + cohere_key = ( + api_key + or litellm.cohere_key + or get_secret("COHERE_API_KEY") + or get_secret("CO_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("COHERE_API_BASE") + or "https://api.cohere.ai/v1/generate" + ) + + model_response = cohere.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=cohere_key, + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="cohere", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "cohere_chat": + cohere_key = ( + api_key + or litellm.cohere_key + or get_secret("COHERE_API_KEY") + or get_secret("CO_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("COHERE_API_BASE") + or "https://api.cohere.ai/v1/chat" + ) + + model_response = cohere_chat.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=cohere_key, + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="cohere_chat", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "maritalk": + maritalk_key = ( + api_key + or litellm.maritalk_key + or get_secret("MARITALK_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("MARITALK_API_BASE") + or "https://chat.maritaca.ai/api/chat/inference" + ) + + model_response = maritalk.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=maritalk_key, + logging_obj=logging, + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="maritalk", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "huggingface": + custom_llm_provider = "huggingface" + huggingface_key = ( + api_key + or litellm.huggingface_key + or os.environ.get("HF_TOKEN") + or os.environ.get("HUGGINGFACE_API_KEY") + or litellm.api_key + ) + hf_headers = headers or litellm.headers + + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + model_response = huggingface.completion( + model=model, + messages=messages, + api_base=api_base, # type: ignore + headers=hf_headers, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=huggingface_key, + acompletion=acompletion, + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, + timeout=timeout, # type: ignore + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion is False + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="huggingface", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "oobabooga": + custom_llm_provider = "oobabooga" + model_response = oobabooga.completion( + model=model, + messages=messages, + model_response=model_response, + api_base=api_base, # type: ignore + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + api_key=None, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + ) + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="oobabooga", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "openrouter": + api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1" + + api_key = ( + api_key + or litellm.api_key + or litellm.openrouter_key + or get_secret("OPENROUTER_API_KEY") + or get_secret("OR_API_KEY") + ) + + openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai" + + openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" + + headers = ( + headers + or litellm.headers + or { + "HTTP-Referer": openrouter_site_url, + "X-Title": openrouter_app_name, + } + ) + + ## Load Config + config = openrouter.OpenrouterConfig.get_config() + for k, v in config.items(): + if k == "extra_body": + # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models + if "extra_body" in optional_params: + optional_params[k].update(v) + else: + optional_params[k] = v + elif k not in optional_params: + optional_params[k] = v + + data = {"model": model, "messages": messages, **optional_params} + + ## COMPLETION CALL + response = openai_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, # type: ignore + ) + ## LOGGING + logging.post_call( + input=messages, api_key=openai.api_key, original_response=response + ) + elif ( + custom_llm_provider == "together_ai" + or ("togethercomputer" in model) + or (model in litellm.together_ai_models) + ): + """ + Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility + """ + custom_llm_provider = "together_ai" + together_ai_key = ( + api_key + or litellm.togetherai_api_key + or get_secret("TOGETHER_AI_TOKEN") + or get_secret("TOGETHERAI_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("TOGETHERAI_API_BASE") + or "https://api.together.xyz/inference" + ) + + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + + model_response = together_ai.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=together_ai_key, + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, + ) + if ( + "stream_tokens" in optional_params + and optional_params["stream_tokens"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="together_ai", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "palm": + palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key + + # palm does not support streaming as yet :( + model_response = palm.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=palm_api_key, + logging_obj=logging, + ) + # fake palm streaming + if "stream" in optional_params and optional_params["stream"] == True: + # fake streaming for palm + resp_string = model_response["choices"][0]["message"]["content"] + response = CustomStreamWrapper( + resp_string, model, custom_llm_provider="palm", logging_obj=logging + ) + return response + response = model_response + elif custom_llm_provider == "gemini": + gemini_api_key = ( + api_key + or get_secret("GEMINI_API_KEY") + or get_secret("PALM_API_KEY") # older palm api key should also work + or litellm.api_key + ) + + # palm does not support streaming as yet :( + model_response = gemini.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=gemini_api_key, + logging_obj=logging, + acompletion=acompletion, + custom_prompt_dict=custom_prompt_dict, + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion == False + ): + response = CustomStreamWrapper( + iter(model_response), + model, + custom_llm_provider="gemini", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "vertex_ai": + vertex_ai_project = ( + optional_params.pop("vertex_project", None) + or optional_params.pop("vertex_ai_project", None) + or litellm.vertex_project + or get_secret("VERTEXAI_PROJECT") + ) + vertex_ai_location = ( + optional_params.pop("vertex_location", None) + or optional_params.pop("vertex_ai_location", None) + or litellm.vertex_location + or get_secret("VERTEXAI_LOCATION") + ) + vertex_credentials = ( + optional_params.pop("vertex_credentials", None) + or optional_params.pop("vertex_ai_credentials", None) + or get_secret("VERTEXAI_CREDENTIALS") + ) + new_params = deepcopy(optional_params) + if "claude-3" in model: + model_response = vertex_ai_anthropic.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + ) + else: + model_response = vertex_ai.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion == False + ): + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="vertex_ai", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "predibase": + tenant_id = ( + optional_params.pop("tenant_id", None) + or optional_params.pop("predibase_tenant_id", None) + or litellm.predibase_tenant_id + or get_secret("PREDIBASE_TENANT_ID") + ) + + api_base = ( + optional_params.pop("api_base", None) + or optional_params.pop("base_url", None) + or litellm.api_base + or get_secret("PREDIBASE_API_BASE") + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.predibase_key + or get_secret("PREDIBASE_API_KEY") + ) + + model_response = predibase_chat_completions.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + acompletion=acompletion, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + api_key=api_key, + tenant_id=tenant_id, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion == False + ): + return response + response = model_response + elif custom_llm_provider == "ai21": + custom_llm_provider = "ai21" + ai21_key = ( + api_key + or litellm.ai21_key + or os.environ.get("AI21_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("AI21_API_BASE") + or "https://api.ai21.com/studio/v1/" + ) + + model_response = ai21.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=ai21_key, + logging_obj=logging, + ) + + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="ai21", + logging_obj=logging, + ) + return response + + ## RESPONSE OBJECT + response = model_response + elif custom_llm_provider == "sagemaker": + # boto3 reads keys from .env + model_response = sagemaker.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + custom_prompt_dict=custom_prompt_dict, + hf_model_name=hf_model_name, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + acompletion=acompletion, + ) + if ( + "stream" in optional_params and optional_params["stream"] == True + ): ## [BETA] + print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER") + from .llms.sagemaker import TokenIterator + + tokenIterator = TokenIterator(model_response, acompletion=acompletion) + response = CustomStreamWrapper( + completion_stream=tokenIterator, + model=model, + custom_llm_provider="sagemaker", + logging_obj=logging, + ) + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=response, + ) + return response + + ## RESPONSE OBJECT + response = model_response + elif custom_llm_provider == "bedrock": + # boto3 reads keys from .env + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + response = bedrock.completion( + model=model, + messages=messages, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + and not isinstance(response, CustomStreamWrapper) + ): + # don't try to access stream object, + if "ai21" in model: + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="bedrock", + logging_obj=logging, + ) + else: + response = CustomStreamWrapper( + iter(response), + model, + custom_llm_provider="bedrock", + logging_obj=logging, + ) + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=response, + ) + + ## RESPONSE OBJECT + response = response + elif custom_llm_provider == "watsonx": + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + response = watsonx.IBMWatsonXAI().completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, # type: ignore + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + timeout=timeout, # type: ignore + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and not isinstance(response, CustomStreamWrapper) + ): + # don't try to access stream object, + response = CustomStreamWrapper( + iter(response), + model, + custom_llm_provider="watsonx", + logging_obj=logging, + ) + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=response, + ) + ## RESPONSE OBJECT + response = response + elif custom_llm_provider == "vllm": + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + model_response = vllm.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + ) + + if ( + "stream" in optional_params and optional_params["stream"] == True + ): ## [BETA] + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="vllm", + logging_obj=logging, + ) + return response + + ## RESPONSE OBJECT + response = model_response + elif custom_llm_provider == "ollama": + api_base = ( + litellm.api_base + or api_base + or get_secret("OLLAMA_API_BASE") + or "http://localhost:11434" + ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) + else: + prompt = prompt_factory( + model=model, + messages=messages, + custom_llm_provider=custom_llm_provider, + ) + if isinstance(prompt, dict): + # for multimode models - ollama/llava prompt_factory returns a dict { + # "prompt": prompt, + # "images": images + # } + prompt, images = prompt["prompt"], prompt["images"] + optional_params["images"] = images + + ## LOGGING + generator = ollama.get_ollama_response( + api_base, + model, + prompt, + optional_params, + logging_obj=logging, + acompletion=acompletion, + model_response=model_response, + encoding=encoding, + ) + if acompletion is True or optional_params.get("stream", False) == True: + return generator + + response = generator + elif custom_llm_provider == "ollama_chat": + api_base = ( + litellm.api_base + or api_base + or get_secret("OLLAMA_API_BASE") + or "http://localhost:11434" + ) + + api_key = ( + api_key + or litellm.ollama_key + or os.environ.get("OLLAMA_API_KEY") + or litellm.api_key + ) + ## LOGGING + generator = ollama_chat.get_ollama_response( + api_base, + api_key, + model, + messages, + optional_params, + logging_obj=logging, + acompletion=acompletion, + model_response=model_response, + encoding=encoding, + ) + if acompletion is True or optional_params.get("stream", False) == True: + return generator + + response = generator + elif custom_llm_provider == "cloudflare": + api_key = ( + api_key + or litellm.cloudflare_api_key + or litellm.api_key + or get_secret("CLOUDFLARE_API_KEY") + ) + account_id = get_secret("CLOUDFLARE_ACCOUNT_ID") + api_base = ( + api_base + or litellm.api_base + or get_secret("CLOUDFLARE_API_BASE") + or f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/" + ) + + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + response = cloudflare.completion( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=api_key, + logging_obj=logging, + ) + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="cloudflare", + logging_obj=logging, + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + ) + response = response + elif ( + custom_llm_provider == "baseten" + or litellm.api_base == "https://app.baseten.co" + ): + custom_llm_provider = "baseten" + baseten_key = ( + api_key + or litellm.baseten_key + or os.environ.get("BASETEN_API_KEY") + or litellm.api_key + ) + + model_response = baseten.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=baseten_key, + logging_obj=logging, + ) + if inspect.isgenerator(model_response) or ( + "stream" in optional_params and optional_params["stream"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="baseten", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "petals" or model in litellm.petals_models: + api_base = api_base or litellm.api_base + + custom_llm_provider = "petals" + stream = optional_params.pop("stream", False) + model_response = petals.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + ) + if stream == True: ## [BETA] + # Fake streaming for petals + resp_string = model_response["choices"][0]["message"]["content"] + response = CustomStreamWrapper( + resp_string, + model, + custom_llm_provider="petals", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "custom": + import requests + + url = litellm.api_base or api_base or "" + if url == None or url == "": + raise ValueError( + "api_base not set. Set api_base or litellm.api_base for custom endpoints" + ) + + """ + assume input to custom LLM api bases follow this format: + resp = requests.post( + api_base, + json={ + 'model': 'meta-llama/Llama-2-13b-hf', # model name + 'params': { + 'prompt': ["The capital of France is P"], + 'max_tokens': 32, + 'temperature': 0.7, + 'top_p': 1.0, + 'top_k': 40, + } + } + ) + + """ + prompt = " ".join([message["content"] for message in messages]) # type: ignore + resp = requests.post( + url, + json={ + "model": model, + "params": { + "prompt": [prompt], + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "top_k": kwargs.get("top_k", 40), + }, + }, + ) + response_json = resp.json() + """ + assume all responses from custom api_bases of this format: + { + 'data': [ + { + 'prompt': 'The capital of France is P', + 'output': ['The capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France'], + 'params': {'temperature': 0.7, 'top_k': 40, 'top_p': 1}}], + 'message': 'ok' + } + ] + } + """ + string_response = response_json["data"][0]["output"][0] + ## RESPONSE OBJECT + model_response["choices"][0]["message"]["content"] = string_response + model_response["created"] = int(time.time()) + model_response["model"] = model + response = model_response + else: + raise ValueError( + f"Unable to map your input to a model. Check your input - {args}" + ) + return response + except Exception as e: + ## Map to OpenAI Exception +> raise exception_type( + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + extra_kwargs=kwargs, + ) + +../main.py:2287: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +model = 'llama-3-8b-instruct', original_exception = KeyError('stream') +custom_llm_provider = 'predibase' +completion_kwargs = {'acompletion': False, 'api_base': None, 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'api_version': None, ...} +extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , 'tenant_id': 'c4768f95'} + + def exception_type( + model, + original_exception, + custom_llm_provider, + completion_kwargs={}, + extra_kwargs={}, + ): + global user_logger_fn, liteDebuggerClient + exception_mapping_worked = False + if litellm.suppress_debug_info is False: + print() # noqa + print( # noqa + "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa + ) # noqa + print( # noqa + "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." # noqa + ) # noqa + print() # noqa + try: + if model: + error_str = str(original_exception) + if isinstance(original_exception, BaseException): + exception_type = type(original_exception).__name__ + else: + exception_type = "" + + ################################################################################ + # Common Extra information needed for all providers + # We pass num retries, api_base, vertex_deployment etc to the exception here + ################################################################################ + + _api_base = litellm.get_api_base(model=model, optional_params=extra_kwargs) + messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) + _vertex_project = extra_kwargs.get("vertex_project") + _vertex_location = extra_kwargs.get("vertex_location") + _metadata = extra_kwargs.get("metadata", {}) or {} + _model_group = _metadata.get("model_group") + _deployment = _metadata.get("deployment") + extra_information = f"\nModel: {model}" + if _api_base: + extra_information += f"\nAPI Base: {_api_base}" + if messages and len(messages) > 0: + extra_information += f"\nMessages: {messages}" + + if _model_group is not None: + extra_information += f"\nmodel_group: {_model_group}\n" + if _deployment is not None: + extra_information += f"\ndeployment: {_deployment}\n" + if _vertex_project is not None: + extra_information += f"\nvertex_project: {_vertex_project}\n" + if _vertex_location is not None: + extra_information += f"\nvertex_location: {_vertex_location}\n" + + # on litellm proxy add key name + team to exceptions + extra_information = _add_key_name_and_team_to_alert( + request_info=extra_information, metadata=_metadata + ) + + ################################################################################ + # End of Common Extra information Needed for all providers + ################################################################################ + + ################################################################################ + #################### Start of Provider Exception mapping #################### + ################################################################################ + + if "Request Timeout Error" in error_str or "Request timed out" in error_str: + exception_mapping_worked = True + raise Timeout( + message=f"APITimeoutError - Request timed out. {extra_information} \n error_str: {error_str}", + model=model, + llm_provider=custom_llm_provider, + ) + + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "text-completion-openai" + or custom_llm_provider == "custom_openai" + or custom_llm_provider in litellm.openai_compatible_providers + ): + # custom_llm_provider is openai, make it OpenAI + if hasattr(original_exception, "message"): + message = original_exception.message + else: + message = str(original_exception) + if message is not None and isinstance(message, str): + message = message.replace("OPENAI", custom_llm_provider.upper()) + message = message.replace("openai", custom_llm_provider) + message = message.replace("OpenAI", custom_llm_provider) + if custom_llm_provider == "openai": + exception_provider = "OpenAI" + "Exception" + else: + exception_provider = ( + custom_llm_provider[0].upper() + + custom_llm_provider[1:] + + "Exception" + ) + + if ( + "This model's maximum context length is" in error_str + or "Request too large" in error_str + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif ( + "invalid_request_error" in error_str + and "model_not_found" in error_str + ): + exception_mapping_worked = True + raise NotFoundError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif ( + "invalid_request_error" in error_str + and "content_policy_violation" in error_str + ): + exception_mapping_worked = True + raise ContentPolicyViolationError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif ( + "invalid_request_error" in error_str + and "Incorrect API key provided" not in error_str + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif ( + "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" + in error_str + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif "Mistral API raised a streaming error" in error_str: + exception_mapping_worked = True + _request = httpx.Request( + method="POST", url="https://api.openai.com/v1" + ) + raise APIError( + status_code=500, + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + request=_request, + ) + elif hasattr(original_exception, "status_code"): + exception_mapping_worked = True + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 404: + exception_mapping_worked = True + raise NotFoundError( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + ) + elif original_exception.status_code == 422: + exception_mapping_worked = True + raise BadRequestError( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + ) + elif original_exception.status_code == 503: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + ) + elif original_exception.status_code == 504: # gateway timeout error + exception_mapping_worked = True + raise Timeout( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + request=original_exception.request, + ) + else: + # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors + raise APIConnectionError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), + ) + elif custom_llm_provider == "anthropic": # one of the anthropics + if hasattr(original_exception, "message"): + if ( + "prompt is too long" in original_exception.message + or "prompt: length" in original_exception.message + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=original_exception.message, + model=model, + llm_provider="anthropic", + response=original_exception.response, + ) + if "Invalid API Key" in original_exception.message: + exception_mapping_worked = True + raise AuthenticationError( + message=original_exception.message, + model=model, + llm_provider="anthropic", + response=original_exception.response, + ) + if hasattr(original_exception, "status_code"): + print_verbose(f"status_code: {original_exception.status_code}") + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"AnthropicException - {original_exception.message}", + llm_provider="anthropic", + model=model, + response=original_exception.response, + ) + elif ( + original_exception.status_code == 400 + or original_exception.status_code == 413 + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"AnthropicException - {original_exception.message}", + model=model, + llm_provider="anthropic", + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"AnthropicException - {original_exception.message}", + model=model, + llm_provider="anthropic", + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"AnthropicException - {original_exception.message}", + llm_provider="anthropic", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 500: + exception_mapping_worked = True + raise APIError( + status_code=500, + message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.", + llm_provider="anthropic", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "replicate": + if "Incorrect authentication token" in error_str: + exception_mapping_worked = True + raise AuthenticationError( + message=f"ReplicateException - {error_str}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) + elif "input is too long" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"ReplicateException - {error_str}", + model=model, + llm_provider="replicate", + response=original_exception.response, + ) + elif exception_type == "ModelError": + exception_mapping_worked = True + raise BadRequestError( + message=f"ReplicateException - {error_str}", + model=model, + llm_provider="replicate", + response=original_exception.response, + ) + elif "Request was throttled" in error_str: + exception_mapping_worked = True + raise RateLimitError( + message=f"ReplicateException - {error_str}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) + elif hasattr(original_exception, "status_code"): + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"ReplicateException - {original_exception.message}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) + elif ( + original_exception.status_code == 400 + or original_exception.status_code == 422 + or original_exception.status_code == 413 + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"ReplicateException - {original_exception.message}", + model=model, + llm_provider="replicate", + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"ReplicateException - {original_exception.message}", + model=model, + llm_provider="replicate", + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"ReplicateException - {original_exception.message}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 500: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"ReplicateException - {original_exception.message}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) + exception_mapping_worked = True + raise APIError( + status_code=500, + message=f"ReplicateException - {str(original_exception)}", + llm_provider="replicate", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "bedrock": + if ( + "too many tokens" in error_str + or "expected maxLength:" in error_str + or "Input is too long" in error_str + or "prompt: length: 1.." in error_str + or "Too many input tokens" in error_str + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"BedrockException: Context Window Error - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) + if "Malformed input request" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"BedrockException - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) + if ( + "Unable to locate credentials" in error_str + or "The security token included in the request is invalid" + in error_str + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"BedrockException Invalid Authentication - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) + if "AccessDeniedException" in error_str: + exception_mapping_worked = True + raise PermissionDeniedError( + message=f"BedrockException PermissionDeniedError - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) + if ( + "throttlingException" in error_str + or "ThrottlingException" in error_str + ): + exception_mapping_worked = True + raise RateLimitError( + message=f"BedrockException: Rate Limit Error - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) + if "Connect timeout on endpoint URL" in error_str: + exception_mapping_worked = True + raise Timeout( + message=f"BedrockException: Timeout Error - {error_str}", + model=model, + llm_provider="bedrock", + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 500: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"BedrockException - {original_exception.message}", + llm_provider="bedrock", + model=model, + response=httpx.Response( + status_code=500, + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), + ), + ) + elif original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"BedrockException - {original_exception.message}", + llm_provider="bedrock", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"BedrockException - {original_exception.message}", + llm_provider="bedrock", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 404: + exception_mapping_worked = True + raise NotFoundError( + message=f"BedrockException - {original_exception.message}", + llm_provider="bedrock", + model=model, + response=original_exception.response, + ) + elif custom_llm_provider == "sagemaker": + if "Unable to locate credentials" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"SagemakerException - {error_str}", + model=model, + llm_provider="sagemaker", + response=original_exception.response, + ) + elif ( + "Input validation error: `best_of` must be > 0 and <= 2" + in error_str + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", + model=model, + llm_provider="sagemaker", + response=original_exception.response, + ) + elif ( + "`inputs` tokens + `max_new_tokens` must be <=" in error_str + or "instance type with more CPU capacity or memory" in error_str + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"SagemakerException - {error_str}", + model=model, + llm_provider="sagemaker", + response=original_exception.response, + ) + elif custom_llm_provider == "vertex_ai": + if ( + "Vertex AI API has not been used in project" in error_str + or "Unable to find your project" in error_str + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"VertexAIException - {error_str} {extra_information}", + model=model, + llm_provider="vertex_ai", + response=original_exception.response, + ) + elif ( + "None Unknown Error." in error_str + or "Content has no parts." in error_str + ): + exception_mapping_worked = True + raise APIError( + message=f"VertexAIException - {error_str} {extra_information}", + status_code=500, + model=model, + llm_provider="vertex_ai", + request=original_exception.request, + ) + elif "403" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"VertexAIException - {error_str} {extra_information}", + model=model, + llm_provider="vertex_ai", + response=original_exception.response, + ) + elif "The response was blocked." in error_str: + exception_mapping_worked = True + raise UnprocessableEntityError( + message=f"VertexAIException - {error_str} {extra_information}", + model=model, + llm_provider="vertex_ai", + response=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), + ) + elif ( + "429 Quota exceeded" in error_str + or "IndexError: list index out of range" in error_str + or "429 Unable to submit request because the service is temporarily out of capacity." + in error_str + ): + exception_mapping_worked = True + raise RateLimitError( + message=f"VertexAIException - {error_str} {extra_information}", + model=model, + llm_provider="vertex_ai", + response=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"VertexAIException - {error_str} {extra_information}", + model=model, + llm_provider="vertex_ai", + response=original_exception.response, + ) + if original_exception.status_code == 500: + exception_mapping_worked = True + raise APIError( + message=f"VertexAIException - {error_str} {extra_information}", + status_code=500, + model=model, + llm_provider="vertex_ai", + request=original_exception.request, + ) + elif custom_llm_provider == "palm" or custom_llm_provider == "gemini": + if "503 Getting metadata" in error_str: + # auth errors look like this + # 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate. + exception_mapping_worked = True + raise BadRequestError( + message=f"GeminiException - Invalid api key", + model=model, + llm_provider="palm", + response=original_exception.response, + ) + if ( + "504 Deadline expired before operation could complete." in error_str + or "504 Deadline Exceeded" in error_str + ): + exception_mapping_worked = True + raise Timeout( + message=f"GeminiException - {original_exception.message}", + model=model, + llm_provider="palm", + ) + if "400 Request payload size exceeds" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"GeminiException - {error_str}", + model=model, + llm_provider="palm", + response=original_exception.response, + ) + if ( + "500 An internal error has occurred." in error_str + or "list index out of range" in error_str + ): + exception_mapping_worked = True + raise APIError( + status_code=getattr(original_exception, "status_code", 500), + message=f"GeminiException - {original_exception.message}", + llm_provider="palm", + model=model, + request=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"GeminiException - {error_str}", + model=model, + llm_provider="palm", + response=original_exception.response, + ) + # Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes + elif custom_llm_provider == "cloudflare": + if "Authentication error" in error_str: + exception_mapping_worked = True + raise AuthenticationError( + message=f"Cloudflare Exception - {original_exception.message}", + llm_provider="cloudflare", + model=model, + response=original_exception.response, + ) + if "must have required property" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"Cloudflare Exception - {original_exception.message}", + llm_provider="cloudflare", + model=model, + response=original_exception.response, + ) + elif ( + custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat" + ): # Cohere + if ( + "invalid api token" in error_str + or "No API key provided." in error_str + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + elif "too many tokens" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"CohereException - {original_exception.message}", + model=model, + llm_provider="cohere", + response=original_exception.response, + ) + elif hasattr(original_exception, "status_code"): + if ( + original_exception.status_code == 400 + or original_exception.status_code == 498 + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 500: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + elif ( + "CohereConnectionError" in exception_type + ): # cohere seems to fire these errors when we load test it (1k+ messages / min) + exception_mapping_worked = True + raise RateLimitError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + elif "invalid type:" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + elif "Unexpected server error" in error_str: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + else: + if hasattr(original_exception, "status_code"): + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + request=original_exception.request, + ) + raise original_exception + elif custom_llm_provider == "huggingface": + if "length limit exceeded" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=error_str, + model=model, + llm_provider="huggingface", + response=original_exception.response, + ) + elif "A valid user token is required" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=error_str, + llm_provider="huggingface", + model=model, + response=original_exception.response, + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"HuggingfaceException - {original_exception.message}", + llm_provider="huggingface", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"HuggingfaceException - {original_exception.message}", + model=model, + llm_provider="huggingface", + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"HuggingfaceException - {original_exception.message}", + model=model, + llm_provider="huggingface", + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"HuggingfaceException - {original_exception.message}", + llm_provider="huggingface", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 503: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"HuggingfaceException - {original_exception.message}", + llm_provider="huggingface", + model=model, + response=original_exception.response, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"HuggingfaceException - {original_exception.message}", + llm_provider="huggingface", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "ai21": + if hasattr(original_exception, "message"): + if "Prompt has too many tokens" in original_exception.message: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"AI21Exception - {original_exception.message}", + model=model, + llm_provider="ai21", + response=original_exception.response, + ) + if "Bad or missing API token." in original_exception.message: + exception_mapping_worked = True + raise BadRequestError( + message=f"AI21Exception - {original_exception.message}", + model=model, + llm_provider="ai21", + response=original_exception.response, + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"AI21Exception - {original_exception.message}", + llm_provider="ai21", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"AI21Exception - {original_exception.message}", + model=model, + llm_provider="ai21", + ) + if original_exception.status_code == 422: + exception_mapping_worked = True + raise BadRequestError( + message=f"AI21Exception - {original_exception.message}", + model=model, + llm_provider="ai21", + response=original_exception.response, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"AI21Exception - {original_exception.message}", + llm_provider="ai21", + model=model, + response=original_exception.response, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"AI21Exception - {original_exception.message}", + llm_provider="ai21", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "nlp_cloud": + if "detail" in error_str: + if "Input text length should not exceed" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"NLPCloudException - {error_str}", + model=model, + llm_provider="nlp_cloud", + response=original_exception.response, + ) + elif "value is not a valid" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"NLPCloudException - {error_str}", + model=model, + llm_provider="nlp_cloud", + response=original_exception.response, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=500, + message=f"NLPCloudException - {error_str}", + model=model, + llm_provider="nlp_cloud", + request=original_exception.request, + ) + if hasattr( + original_exception, "status_code" + ): # https://docs.nlpcloud.com/?shell#errors + if ( + original_exception.status_code == 400 + or original_exception.status_code == 406 + or original_exception.status_code == 413 + or original_exception.status_code == 422 + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"NLPCloudException - {original_exception.message}", + llm_provider="nlp_cloud", + model=model, + response=original_exception.response, + ) + elif ( + original_exception.status_code == 401 + or original_exception.status_code == 403 + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"NLPCloudException - {original_exception.message}", + llm_provider="nlp_cloud", + model=model, + response=original_exception.response, + ) + elif ( + original_exception.status_code == 522 + or original_exception.status_code == 524 + ): + exception_mapping_worked = True + raise Timeout( + message=f"NLPCloudException - {original_exception.message}", + model=model, + llm_provider="nlp_cloud", + ) + elif ( + original_exception.status_code == 429 + or original_exception.status_code == 402 + ): + exception_mapping_worked = True + raise RateLimitError( + message=f"NLPCloudException - {original_exception.message}", + llm_provider="nlp_cloud", + model=model, + response=original_exception.response, + ) + elif ( + original_exception.status_code == 500 + or original_exception.status_code == 503 + ): + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"NLPCloudException - {original_exception.message}", + llm_provider="nlp_cloud", + model=model, + request=original_exception.request, + ) + elif ( + original_exception.status_code == 504 + or original_exception.status_code == 520 + ): + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"NLPCloudException - {original_exception.message}", + model=model, + llm_provider="nlp_cloud", + response=original_exception.response, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"NLPCloudException - {original_exception.message}", + llm_provider="nlp_cloud", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "together_ai": + import json + + try: + error_response = json.loads(error_str) + except: + error_response = {"error": error_str} + if ( + "error" in error_response + and "`inputs` tokens + `max_new_tokens` must be <=" + in error_response["error"] + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"TogetherAIException - {error_response['error']}", + model=model, + llm_provider="together_ai", + response=original_exception.response, + ) + elif ( + "error" in error_response + and "invalid private key" in error_response["error"] + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"TogetherAIException - {error_response['error']}", + llm_provider="together_ai", + model=model, + response=original_exception.response, + ) + elif ( + "error" in error_response + and "INVALID_ARGUMENT" in error_response["error"] + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"TogetherAIException - {error_response['error']}", + model=model, + llm_provider="together_ai", + response=original_exception.response, + ) + + elif ( + "error" in error_response + and "API key doesn't match expected format." + in error_response["error"] + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"TogetherAIException - {error_response['error']}", + model=model, + llm_provider="together_ai", + response=original_exception.response, + ) + elif ( + "error_type" in error_response + and error_response["error_type"] == "validation" + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"TogetherAIException - {error_response['error']}", + model=model, + llm_provider="together_ai", + response=original_exception.response, + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"TogetherAIException - {original_exception.message}", + model=model, + llm_provider="together_ai", + ) + elif original_exception.status_code == 422: + exception_mapping_worked = True + raise BadRequestError( + message=f"TogetherAIException - {error_response['error']}", + model=model, + llm_provider="together_ai", + response=original_exception.response, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"TogetherAIException - {original_exception.message}", + llm_provider="together_ai", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 524: + exception_mapping_worked = True + raise Timeout( + message=f"TogetherAIException - {original_exception.message}", + llm_provider="together_ai", + model=model, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"TogetherAIException - {original_exception.message}", + llm_provider="together_ai", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "aleph_alpha": + if ( + "This is longer than the model's maximum context length" + in error_str + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + response=original_exception.response, + ) + elif "InvalidToken" in error_str or "No token provided" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + response=original_exception.response, + ) + elif hasattr(original_exception, "status_code"): + print_verbose(f"status code: {original_exception.status_code}") + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + ) + elif original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 500: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + response=original_exception.response, + ) + raise original_exception + raise original_exception + elif ( + custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat" + ): + if isinstance(original_exception, dict): + error_str = original_exception.get("error", "") + else: + error_str = str(original_exception) + if "no such file or directory" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}", + model=model, + llm_provider="ollama", + response=original_exception.response, + ) + elif "Failed to establish a new connection" in error_str: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"OllamaException: {original_exception}", + llm_provider="ollama", + model=model, + response=original_exception.response, + ) + elif "Invalid response object from API" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"OllamaException: {original_exception}", + llm_provider="ollama", + model=model, + response=original_exception.response, + ) + elif "Read timed out" in error_str: + exception_mapping_worked = True + raise Timeout( + message=f"OllamaException: {original_exception}", + llm_provider="ollama", + model=model, + ) + elif custom_llm_provider == "vllm": + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 0: + exception_mapping_worked = True + raise APIConnectionError( + message=f"VLLMException - {original_exception.message}", + llm_provider="vllm", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "azure": + if "Internal server error" in error_str: + exception_mapping_worked = True + raise APIError( + status_code=500, + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + request=httpx.Request(method="POST", url="https://openai.com/"), + ) + elif "This model's maximum context length is" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + response=original_exception.response, + ) + elif "DeploymentNotFound" in error_str: + exception_mapping_worked = True + raise NotFoundError( + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + response=original_exception.response, + ) + elif ( + "invalid_request_error" in error_str + and "content_policy_violation" in error_str + ) or ( + "The response was filtered due to the prompt triggering Azure OpenAI's content management" + in error_str + ): + exception_mapping_worked = True + raise ContentPolicyViolationError( + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + response=original_exception.response, + ) + elif "invalid_request_error" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + response=original_exception.response, + ) + elif ( + "The api_key client option must be set either by passing api_key to the client or by setting" + in error_str + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"{exception_provider} - {original_exception.message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif hasattr(original_exception, "status_code"): + exception_mapping_worked = True + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"AzureException - {original_exception.message} {extra_information}", + model=model, + llm_provider="azure", + ) + if original_exception.status_code == 422: + exception_mapping_worked = True + raise BadRequestError( + message=f"AzureException - {original_exception.message} {extra_information}", + model=model, + llm_provider="azure", + response=original_exception.response, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"AzureException - {original_exception.message} {extra_information}", + model=model, + llm_provider="azure", + response=original_exception.response, + ) + elif original_exception.status_code == 503: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"AzureException - {original_exception.message} {extra_information}", + model=model, + llm_provider="azure", + response=original_exception.response, + ) + elif original_exception.status_code == 504: # gateway timeout error + exception_mapping_worked = True + raise Timeout( + message=f"AzureException - {original_exception.message} {extra_information}", + model=model, + llm_provider="azure", + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + request=httpx.Request( + method="POST", url="https://openai.com/" + ), + ) + else: + # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors + raise APIConnectionError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider="azure", + model=model, + request=httpx.Request(method="POST", url="https://openai.com/"), + ) + if ( + "BadRequestError.__init__() missing 1 required positional argument: 'param'" + in str(original_exception) + ): # deal with edge-case invalid request error bug in openai-python sdk + exception_mapping_worked = True + raise BadRequestError( + message=f"{exception_provider}: This can happen due to missing AZURE_API_VERSION: {str(original_exception)}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + ) + else: # ensure generic errors always return APIConnectionError= + exception_mapping_worked = True + if hasattr(original_exception, "request"): + raise APIConnectionError( + message=f"{str(original_exception)}", + llm_provider=custom_llm_provider, + model=model, + request=original_exception.request, + ) + else: + raise APIConnectionError( + message=f"{str(original_exception)}", + llm_provider=custom_llm_provider, + model=model, + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), # stub the request + ) + except Exception as e: + # LOGGING + exception_logging( + logger_fn=user_logger_fn, + additional_args={ + "exception_mapping_worked": exception_mapping_worked, + "original_exception": original_exception, + }, + exception=e, + ) + ## AUTH ERROR + if isinstance(e, AuthenticationError) and ( + litellm.email or "LITELLM_EMAIL" in os.environ + ): + threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start() + # don't let an error with mapping interrupt the user from receiving an error from the llm api calls + if exception_mapping_worked: +> raise e + +../utils.py:9353: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +model = 'llama-3-8b-instruct', original_exception = KeyError('stream') +custom_llm_provider = 'predibase' +completion_kwargs = {'acompletion': False, 'api_base': None, 'api_key': 'pb_Qg9YbQo7UqqHdu0ozxN_aw', 'api_version': None, ...} +extra_kwargs = {'api_base': 'https://serving.app.predibase.com', 'litellm_call_id': 'cf0ea464-1b45-4473-8e55-6bf6809df7a7', 'litellm_logging_obj': , 'tenant_id': 'c4768f95'} + + def exception_type( + model, + original_exception, + custom_llm_provider, + completion_kwargs={}, + extra_kwargs={}, + ): + global user_logger_fn, liteDebuggerClient + exception_mapping_worked = False + if litellm.suppress_debug_info is False: + print() # noqa + print( # noqa + "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa + ) # noqa + print( # noqa + "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." # noqa + ) # noqa + print() # noqa + try: + if model: + error_str = str(original_exception) + if isinstance(original_exception, BaseException): + exception_type = type(original_exception).__name__ + else: + exception_type = "" + + ################################################################################ + # Common Extra information needed for all providers + # We pass num retries, api_base, vertex_deployment etc to the exception here + ################################################################################ + + _api_base = litellm.get_api_base(model=model, optional_params=extra_kwargs) + messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) + _vertex_project = extra_kwargs.get("vertex_project") + _vertex_location = extra_kwargs.get("vertex_location") + _metadata = extra_kwargs.get("metadata", {}) or {} + _model_group = _metadata.get("model_group") + _deployment = _metadata.get("deployment") + extra_information = f"\nModel: {model}" + if _api_base: + extra_information += f"\nAPI Base: {_api_base}" + if messages and len(messages) > 0: + extra_information += f"\nMessages: {messages}" + + if _model_group is not None: + extra_information += f"\nmodel_group: {_model_group}\n" + if _deployment is not None: + extra_information += f"\ndeployment: {_deployment}\n" + if _vertex_project is not None: + extra_information += f"\nvertex_project: {_vertex_project}\n" + if _vertex_location is not None: + extra_information += f"\nvertex_location: {_vertex_location}\n" + + # on litellm proxy add key name + team to exceptions + extra_information = _add_key_name_and_team_to_alert( + request_info=extra_information, metadata=_metadata + ) + + ################################################################################ + # End of Common Extra information Needed for all providers + ################################################################################ + + ################################################################################ + #################### Start of Provider Exception mapping #################### + ################################################################################ + + if "Request Timeout Error" in error_str or "Request timed out" in error_str: + exception_mapping_worked = True + raise Timeout( + message=f"APITimeoutError - Request timed out. {extra_information} \n error_str: {error_str}", + model=model, + llm_provider=custom_llm_provider, + ) + + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "text-completion-openai" + or custom_llm_provider == "custom_openai" + or custom_llm_provider in litellm.openai_compatible_providers + ): + # custom_llm_provider is openai, make it OpenAI + if hasattr(original_exception, "message"): + message = original_exception.message + else: + message = str(original_exception) + if message is not None and isinstance(message, str): + message = message.replace("OPENAI", custom_llm_provider.upper()) + message = message.replace("openai", custom_llm_provider) + message = message.replace("OpenAI", custom_llm_provider) + if custom_llm_provider == "openai": + exception_provider = "OpenAI" + "Exception" + else: + exception_provider = ( + custom_llm_provider[0].upper() + + custom_llm_provider[1:] + + "Exception" + ) + + if ( + "This model's maximum context length is" in error_str + or "Request too large" in error_str + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif ( + "invalid_request_error" in error_str + and "model_not_found" in error_str + ): + exception_mapping_worked = True + raise NotFoundError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif ( + "invalid_request_error" in error_str + and "content_policy_violation" in error_str + ): + exception_mapping_worked = True + raise ContentPolicyViolationError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif ( + "invalid_request_error" in error_str + and "Incorrect API key provided" not in error_str + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif ( + "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" + in error_str + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif "Mistral API raised a streaming error" in error_str: + exception_mapping_worked = True + _request = httpx.Request( + method="POST", url="https://api.openai.com/v1" + ) + raise APIError( + status_code=500, + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + request=_request, + ) + elif hasattr(original_exception, "status_code"): + exception_mapping_worked = True + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 404: + exception_mapping_worked = True + raise NotFoundError( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + ) + elif original_exception.status_code == 422: + exception_mapping_worked = True + raise BadRequestError( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + ) + elif original_exception.status_code == 503: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + ) + elif original_exception.status_code == 504: # gateway timeout error + exception_mapping_worked = True + raise Timeout( + message=f"{exception_provider} - {message} {extra_information}", + model=model, + llm_provider=custom_llm_provider, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + request=original_exception.request, + ) + else: + # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors + raise APIConnectionError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), + ) + elif custom_llm_provider == "anthropic": # one of the anthropics + if hasattr(original_exception, "message"): + if ( + "prompt is too long" in original_exception.message + or "prompt: length" in original_exception.message + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=original_exception.message, + model=model, + llm_provider="anthropic", + response=original_exception.response, + ) + if "Invalid API Key" in original_exception.message: + exception_mapping_worked = True + raise AuthenticationError( + message=original_exception.message, + model=model, + llm_provider="anthropic", + response=original_exception.response, + ) + if hasattr(original_exception, "status_code"): + print_verbose(f"status_code: {original_exception.status_code}") + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"AnthropicException - {original_exception.message}", + llm_provider="anthropic", + model=model, + response=original_exception.response, + ) + elif ( + original_exception.status_code == 400 + or original_exception.status_code == 413 + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"AnthropicException - {original_exception.message}", + model=model, + llm_provider="anthropic", + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"AnthropicException - {original_exception.message}", + model=model, + llm_provider="anthropic", + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"AnthropicException - {original_exception.message}", + llm_provider="anthropic", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 500: + exception_mapping_worked = True + raise APIError( + status_code=500, + message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.", + llm_provider="anthropic", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "replicate": + if "Incorrect authentication token" in error_str: + exception_mapping_worked = True + raise AuthenticationError( + message=f"ReplicateException - {error_str}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) + elif "input is too long" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"ReplicateException - {error_str}", + model=model, + llm_provider="replicate", + response=original_exception.response, + ) + elif exception_type == "ModelError": + exception_mapping_worked = True + raise BadRequestError( + message=f"ReplicateException - {error_str}", + model=model, + llm_provider="replicate", + response=original_exception.response, + ) + elif "Request was throttled" in error_str: + exception_mapping_worked = True + raise RateLimitError( + message=f"ReplicateException - {error_str}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) + elif hasattr(original_exception, "status_code"): + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"ReplicateException - {original_exception.message}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) + elif ( + original_exception.status_code == 400 + or original_exception.status_code == 422 + or original_exception.status_code == 413 + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"ReplicateException - {original_exception.message}", + model=model, + llm_provider="replicate", + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"ReplicateException - {original_exception.message}", + model=model, + llm_provider="replicate", + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"ReplicateException - {original_exception.message}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 500: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"ReplicateException - {original_exception.message}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) + exception_mapping_worked = True + raise APIError( + status_code=500, + message=f"ReplicateException - {str(original_exception)}", + llm_provider="replicate", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "bedrock": + if ( + "too many tokens" in error_str + or "expected maxLength:" in error_str + or "Input is too long" in error_str + or "prompt: length: 1.." in error_str + or "Too many input tokens" in error_str + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"BedrockException: Context Window Error - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) + if "Malformed input request" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"BedrockException - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) + if ( + "Unable to locate credentials" in error_str + or "The security token included in the request is invalid" + in error_str + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"BedrockException Invalid Authentication - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) + if "AccessDeniedException" in error_str: + exception_mapping_worked = True + raise PermissionDeniedError( + message=f"BedrockException PermissionDeniedError - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) + if ( + "throttlingException" in error_str + or "ThrottlingException" in error_str + ): + exception_mapping_worked = True + raise RateLimitError( + message=f"BedrockException: Rate Limit Error - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) + if "Connect timeout on endpoint URL" in error_str: + exception_mapping_worked = True + raise Timeout( + message=f"BedrockException: Timeout Error - {error_str}", + model=model, + llm_provider="bedrock", + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 500: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"BedrockException - {original_exception.message}", + llm_provider="bedrock", + model=model, + response=httpx.Response( + status_code=500, + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), + ), + ) + elif original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"BedrockException - {original_exception.message}", + llm_provider="bedrock", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"BedrockException - {original_exception.message}", + llm_provider="bedrock", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 404: + exception_mapping_worked = True + raise NotFoundError( + message=f"BedrockException - {original_exception.message}", + llm_provider="bedrock", + model=model, + response=original_exception.response, + ) + elif custom_llm_provider == "sagemaker": + if "Unable to locate credentials" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"SagemakerException - {error_str}", + model=model, + llm_provider="sagemaker", + response=original_exception.response, + ) + elif ( + "Input validation error: `best_of` must be > 0 and <= 2" + in error_str + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", + model=model, + llm_provider="sagemaker", + response=original_exception.response, + ) + elif ( + "`inputs` tokens + `max_new_tokens` must be <=" in error_str + or "instance type with more CPU capacity or memory" in error_str + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"SagemakerException - {error_str}", + model=model, + llm_provider="sagemaker", + response=original_exception.response, + ) + elif custom_llm_provider == "vertex_ai": + if ( + "Vertex AI API has not been used in project" in error_str + or "Unable to find your project" in error_str + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"VertexAIException - {error_str} {extra_information}", + model=model, + llm_provider="vertex_ai", + response=original_exception.response, + ) + elif ( + "None Unknown Error." in error_str + or "Content has no parts." in error_str + ): + exception_mapping_worked = True + raise APIError( + message=f"VertexAIException - {error_str} {extra_information}", + status_code=500, + model=model, + llm_provider="vertex_ai", + request=original_exception.request, + ) + elif "403" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"VertexAIException - {error_str} {extra_information}", + model=model, + llm_provider="vertex_ai", + response=original_exception.response, + ) + elif "The response was blocked." in error_str: + exception_mapping_worked = True + raise UnprocessableEntityError( + message=f"VertexAIException - {error_str} {extra_information}", + model=model, + llm_provider="vertex_ai", + response=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), + ) + elif ( + "429 Quota exceeded" in error_str + or "IndexError: list index out of range" in error_str + or "429 Unable to submit request because the service is temporarily out of capacity." + in error_str + ): + exception_mapping_worked = True + raise RateLimitError( + message=f"VertexAIException - {error_str} {extra_information}", + model=model, + llm_provider="vertex_ai", + response=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"VertexAIException - {error_str} {extra_information}", + model=model, + llm_provider="vertex_ai", + response=original_exception.response, + ) + if original_exception.status_code == 500: + exception_mapping_worked = True + raise APIError( + message=f"VertexAIException - {error_str} {extra_information}", + status_code=500, + model=model, + llm_provider="vertex_ai", + request=original_exception.request, + ) + elif custom_llm_provider == "palm" or custom_llm_provider == "gemini": + if "503 Getting metadata" in error_str: + # auth errors look like this + # 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate. + exception_mapping_worked = True + raise BadRequestError( + message=f"GeminiException - Invalid api key", + model=model, + llm_provider="palm", + response=original_exception.response, + ) + if ( + "504 Deadline expired before operation could complete." in error_str + or "504 Deadline Exceeded" in error_str + ): + exception_mapping_worked = True + raise Timeout( + message=f"GeminiException - {original_exception.message}", + model=model, + llm_provider="palm", + ) + if "400 Request payload size exceeds" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"GeminiException - {error_str}", + model=model, + llm_provider="palm", + response=original_exception.response, + ) + if ( + "500 An internal error has occurred." in error_str + or "list index out of range" in error_str + ): + exception_mapping_worked = True + raise APIError( + status_code=getattr(original_exception, "status_code", 500), + message=f"GeminiException - {original_exception.message}", + llm_provider="palm", + model=model, + request=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"GeminiException - {error_str}", + model=model, + llm_provider="palm", + response=original_exception.response, + ) + # Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes + elif custom_llm_provider == "cloudflare": + if "Authentication error" in error_str: + exception_mapping_worked = True + raise AuthenticationError( + message=f"Cloudflare Exception - {original_exception.message}", + llm_provider="cloudflare", + model=model, + response=original_exception.response, + ) + if "must have required property" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"Cloudflare Exception - {original_exception.message}", + llm_provider="cloudflare", + model=model, + response=original_exception.response, + ) + elif ( + custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat" + ): # Cohere + if ( + "invalid api token" in error_str + or "No API key provided." in error_str + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + elif "too many tokens" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"CohereException - {original_exception.message}", + model=model, + llm_provider="cohere", + response=original_exception.response, + ) + elif hasattr(original_exception, "status_code"): + if ( + original_exception.status_code == 400 + or original_exception.status_code == 498 + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 500: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + elif ( + "CohereConnectionError" in exception_type + ): # cohere seems to fire these errors when we load test it (1k+ messages / min) + exception_mapping_worked = True + raise RateLimitError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + elif "invalid type:" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + elif "Unexpected server error" in error_str: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + response=original_exception.response, + ) + else: + if hasattr(original_exception, "status_code"): + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"CohereException - {original_exception.message}", + llm_provider="cohere", + model=model, + request=original_exception.request, + ) + raise original_exception + elif custom_llm_provider == "huggingface": + if "length limit exceeded" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=error_str, + model=model, + llm_provider="huggingface", + response=original_exception.response, + ) + elif "A valid user token is required" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=error_str, + llm_provider="huggingface", + model=model, + response=original_exception.response, + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"HuggingfaceException - {original_exception.message}", + llm_provider="huggingface", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"HuggingfaceException - {original_exception.message}", + model=model, + llm_provider="huggingface", + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"HuggingfaceException - {original_exception.message}", + model=model, + llm_provider="huggingface", + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"HuggingfaceException - {original_exception.message}", + llm_provider="huggingface", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 503: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"HuggingfaceException - {original_exception.message}", + llm_provider="huggingface", + model=model, + response=original_exception.response, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"HuggingfaceException - {original_exception.message}", + llm_provider="huggingface", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "ai21": + if hasattr(original_exception, "message"): + if "Prompt has too many tokens" in original_exception.message: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"AI21Exception - {original_exception.message}", + model=model, + llm_provider="ai21", + response=original_exception.response, + ) + if "Bad or missing API token." in original_exception.message: + exception_mapping_worked = True + raise BadRequestError( + message=f"AI21Exception - {original_exception.message}", + model=model, + llm_provider="ai21", + response=original_exception.response, + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"AI21Exception - {original_exception.message}", + llm_provider="ai21", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"AI21Exception - {original_exception.message}", + model=model, + llm_provider="ai21", + ) + if original_exception.status_code == 422: + exception_mapping_worked = True + raise BadRequestError( + message=f"AI21Exception - {original_exception.message}", + model=model, + llm_provider="ai21", + response=original_exception.response, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"AI21Exception - {original_exception.message}", + llm_provider="ai21", + model=model, + response=original_exception.response, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"AI21Exception - {original_exception.message}", + llm_provider="ai21", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "nlp_cloud": + if "detail" in error_str: + if "Input text length should not exceed" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"NLPCloudException - {error_str}", + model=model, + llm_provider="nlp_cloud", + response=original_exception.response, + ) + elif "value is not a valid" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"NLPCloudException - {error_str}", + model=model, + llm_provider="nlp_cloud", + response=original_exception.response, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=500, + message=f"NLPCloudException - {error_str}", + model=model, + llm_provider="nlp_cloud", + request=original_exception.request, + ) + if hasattr( + original_exception, "status_code" + ): # https://docs.nlpcloud.com/?shell#errors + if ( + original_exception.status_code == 400 + or original_exception.status_code == 406 + or original_exception.status_code == 413 + or original_exception.status_code == 422 + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"NLPCloudException - {original_exception.message}", + llm_provider="nlp_cloud", + model=model, + response=original_exception.response, + ) + elif ( + original_exception.status_code == 401 + or original_exception.status_code == 403 + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"NLPCloudException - {original_exception.message}", + llm_provider="nlp_cloud", + model=model, + response=original_exception.response, + ) + elif ( + original_exception.status_code == 522 + or original_exception.status_code == 524 + ): + exception_mapping_worked = True + raise Timeout( + message=f"NLPCloudException - {original_exception.message}", + model=model, + llm_provider="nlp_cloud", + ) + elif ( + original_exception.status_code == 429 + or original_exception.status_code == 402 + ): + exception_mapping_worked = True + raise RateLimitError( + message=f"NLPCloudException - {original_exception.message}", + llm_provider="nlp_cloud", + model=model, + response=original_exception.response, + ) + elif ( + original_exception.status_code == 500 + or original_exception.status_code == 503 + ): + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"NLPCloudException - {original_exception.message}", + llm_provider="nlp_cloud", + model=model, + request=original_exception.request, + ) + elif ( + original_exception.status_code == 504 + or original_exception.status_code == 520 + ): + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"NLPCloudException - {original_exception.message}", + model=model, + llm_provider="nlp_cloud", + response=original_exception.response, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"NLPCloudException - {original_exception.message}", + llm_provider="nlp_cloud", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "together_ai": + import json + + try: + error_response = json.loads(error_str) + except: + error_response = {"error": error_str} + if ( + "error" in error_response + and "`inputs` tokens + `max_new_tokens` must be <=" + in error_response["error"] + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"TogetherAIException - {error_response['error']}", + model=model, + llm_provider="together_ai", + response=original_exception.response, + ) + elif ( + "error" in error_response + and "invalid private key" in error_response["error"] + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"TogetherAIException - {error_response['error']}", + llm_provider="together_ai", + model=model, + response=original_exception.response, + ) + elif ( + "error" in error_response + and "INVALID_ARGUMENT" in error_response["error"] + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"TogetherAIException - {error_response['error']}", + model=model, + llm_provider="together_ai", + response=original_exception.response, + ) + + elif ( + "error" in error_response + and "API key doesn't match expected format." + in error_response["error"] + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"TogetherAIException - {error_response['error']}", + model=model, + llm_provider="together_ai", + response=original_exception.response, + ) + elif ( + "error_type" in error_response + and error_response["error_type"] == "validation" + ): + exception_mapping_worked = True + raise BadRequestError( + message=f"TogetherAIException - {error_response['error']}", + model=model, + llm_provider="together_ai", + response=original_exception.response, + ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"TogetherAIException - {original_exception.message}", + model=model, + llm_provider="together_ai", + ) + elif original_exception.status_code == 422: + exception_mapping_worked = True + raise BadRequestError( + message=f"TogetherAIException - {error_response['error']}", + model=model, + llm_provider="together_ai", + response=original_exception.response, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"TogetherAIException - {original_exception.message}", + llm_provider="together_ai", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 524: + exception_mapping_worked = True + raise Timeout( + message=f"TogetherAIException - {original_exception.message}", + llm_provider="together_ai", + model=model, + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"TogetherAIException - {original_exception.message}", + llm_provider="together_ai", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "aleph_alpha": + if ( + "This is longer than the model's maximum context length" + in error_str + ): + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + response=original_exception.response, + ) + elif "InvalidToken" in error_str or "No token provided" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + response=original_exception.response, + ) + elif hasattr(original_exception, "status_code"): + print_verbose(f"status code: {original_exception.status_code}") + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + ) + elif original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 500: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"AlephAlphaException - {original_exception.message}", + llm_provider="aleph_alpha", + model=model, + response=original_exception.response, + ) + raise original_exception + raise original_exception + elif ( + custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat" + ): + if isinstance(original_exception, dict): + error_str = original_exception.get("error", "") + else: + error_str = str(original_exception) + if "no such file or directory" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}", + model=model, + llm_provider="ollama", + response=original_exception.response, + ) + elif "Failed to establish a new connection" in error_str: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"OllamaException: {original_exception}", + llm_provider="ollama", + model=model, + response=original_exception.response, + ) + elif "Invalid response object from API" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"OllamaException: {original_exception}", + llm_provider="ollama", + model=model, + response=original_exception.response, + ) + elif "Read timed out" in error_str: + exception_mapping_worked = True + raise Timeout( + message=f"OllamaException: {original_exception}", + llm_provider="ollama", + model=model, + ) + elif custom_llm_provider == "vllm": + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 0: + exception_mapping_worked = True + raise APIConnectionError( + message=f"VLLMException - {original_exception.message}", + llm_provider="vllm", + model=model, + request=original_exception.request, + ) + elif custom_llm_provider == "azure": + if "Internal server error" in error_str: + exception_mapping_worked = True + raise APIError( + status_code=500, + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + request=httpx.Request(method="POST", url="https://openai.com/"), + ) + elif "This model's maximum context length is" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + response=original_exception.response, + ) + elif "DeploymentNotFound" in error_str: + exception_mapping_worked = True + raise NotFoundError( + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + response=original_exception.response, + ) + elif ( + "invalid_request_error" in error_str + and "content_policy_violation" in error_str + ) or ( + "The response was filtered due to the prompt triggering Azure OpenAI's content management" + in error_str + ): + exception_mapping_worked = True + raise ContentPolicyViolationError( + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + response=original_exception.response, + ) + elif "invalid_request_error" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + response=original_exception.response, + ) + elif ( + "The api_key client option must be set either by passing api_key to the client or by setting" + in error_str + ): + exception_mapping_worked = True + raise AuthenticationError( + message=f"{exception_provider} - {original_exception.message} {extra_information}", + llm_provider=custom_llm_provider, + model=model, + response=original_exception.response, + ) + elif hasattr(original_exception, "status_code"): + exception_mapping_worked = True + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + response=original_exception.response, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"AzureException - {original_exception.message} {extra_information}", + model=model, + llm_provider="azure", + ) + if original_exception.status_code == 422: + exception_mapping_worked = True + raise BadRequestError( + message=f"AzureException - {original_exception.message} {extra_information}", + model=model, + llm_provider="azure", + response=original_exception.response, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"AzureException - {original_exception.message} {extra_information}", + model=model, + llm_provider="azure", + response=original_exception.response, + ) + elif original_exception.status_code == 503: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"AzureException - {original_exception.message} {extra_information}", + model=model, + llm_provider="azure", + response=original_exception.response, + ) + elif original_exception.status_code == 504: # gateway timeout error + exception_mapping_worked = True + raise Timeout( + message=f"AzureException - {original_exception.message} {extra_information}", + model=model, + llm_provider="azure", + ) + else: + exception_mapping_worked = True + raise APIError( + status_code=original_exception.status_code, + message=f"AzureException - {original_exception.message} {extra_information}", + llm_provider="azure", + model=model, + request=httpx.Request( + method="POST", url="https://openai.com/" + ), + ) + else: + # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors + raise APIConnectionError( + message=f"{exception_provider} - {message} {extra_information}", + llm_provider="azure", + model=model, + request=httpx.Request(method="POST", url="https://openai.com/"), + ) + if ( + "BadRequestError.__init__() missing 1 required positional argument: 'param'" + in str(original_exception) + ): # deal with edge-case invalid request error bug in openai-python sdk + exception_mapping_worked = True + raise BadRequestError( + message=f"{exception_provider}: This can happen due to missing AZURE_API_VERSION: {str(original_exception)}", + model=model, + llm_provider=custom_llm_provider, + response=original_exception.response, + ) + else: # ensure generic errors always return APIConnectionError= + exception_mapping_worked = True + if hasattr(original_exception, "request"): + raise APIConnectionError( + message=f"{str(original_exception)}", + llm_provider=custom_llm_provider, + model=model, + request=original_exception.request, + ) + else: +> raise APIConnectionError( + message=f"{str(original_exception)}", + llm_provider=custom_llm_provider, + model=model, + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), # stub the request + ) +E litellm.exceptions.APIConnectionError: 'stream' + +../utils.py:9328: APIConnectionError + +During handling of the above exception, another exception occurred: + +sync_mode = True + + @pytest.mark.parametrize("sync_mode", [True, False]) + @pytest.mark.asyncio + async def test_completion_predibase_streaming(sync_mode): + try: + litellm.set_verbose = True + + if sync_mode: + response = completion( + model="predibase/llama-3-8b-instruct", + tenant_id="c4768f95", + api_base="https://serving.app.predibase.com", + api_key=os.getenv("PREDIBASE_API_KEY"), + messages=[{"role": "user", "content": "What is the meaning of life?"}], + stream=True, + ) + + complete_response = "" + for idx, init_chunk in enumerate(response): + chunk, finished = streaming_format_tests(idx, init_chunk) + complete_response += chunk + custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"] + print(f"custom_llm_provider: {custom_llm_provider}") + assert custom_llm_provider == "predibase" + if finished: + assert isinstance( + init_chunk.choices[0], litellm.utils.StreamingChoices + ) + break + if complete_response.strip() == "": + raise Exception("Empty response received") + else: + response = await litellm.acompletion( + model="predibase/llama-3-8b-instruct", + tenant_id="c4768f95", + api_base="https://serving.app.predibase.com", + api_key=os.getenv("PREDIBASE_API_KEY"), + messages=[{"role": "user", "content": "What is the meaning of life?"}], + stream=True, + ) + + # await response + + complete_response = "" + idx = 0 + async for init_chunk in response: + chunk, finished = streaming_format_tests(idx, init_chunk) + complete_response += chunk + custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"] + print(f"custom_llm_provider: {custom_llm_provider}") + assert custom_llm_provider == "predibase" + idx += 1 + if finished: + assert isinstance( + init_chunk.choices[0], litellm.utils.StreamingChoices + ) + break + if complete_response.strip() == "": + raise Exception("Empty response received") + + print(f"complete_response: {complete_response}") + except litellm.Timeout as e: + pass + except Exception as e: +> pytest.fail(f"Error occurred: {e}") +E Failed: Error occurred: 'stream' + +test_streaming.py:373: Failed +---------------------------- Captured stdout setup ----------------------------- + +----------------------------- Captured stdout call ----------------------------- + + +Request to litellm: +litellm.completion(model='predibase/llama-3-8b-instruct', tenant_id='c4768f95', api_base='https://serving.app.predibase.com', api_key='pb_Qg9YbQo7UqqHdu0ozxN_aw', messages=[{'role': 'user', 'content': 'What is the meaning of life?'}], stream=True) + + +self.optional_params: {} +SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False +UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model=llama-3-8b-instruct, custom_llm_provider=predibase +Final returned optional params: {'stream': True, 'tenant_id': 'c4768f95'} +self.optional_params: {'stream': True, 'tenant_id': 'c4768f95'} + + +POST Request Sent from LiteLLM: +curl -X POST \ +https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate_stream \ +-H 'content-type: application/json' -H 'Authorization: Bearer pb_Qg********************' \ +-d '{'inputs': 'What is the meaning of life?', 'parameters': {'details': True, 'max_new_tokens': 256, 'return_full_text': False}}' + + + +Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new +LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. + +Logging Details: logger_fn - None | callable(logger_fn) - False +Logging Details LiteLLM-Failure Call +self.failure_callback: [] =============================== warnings summary =============================== -../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 24 warnings +../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings /opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning) @@ -47,86 +6906,90 @@ final cost: 0; prompt_tokens_cost_usd_dollar: 0; completion_tokens_cost_usd_doll /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:454: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:466 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:466: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:474 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:474: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:509 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:509: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:487 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:487: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:546 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:546: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:532 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:532: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:840 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:840: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:569 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:569: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:867 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:867: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:864 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:864: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../proxy/_types.py:886 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:886: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ +../proxy/_types.py:891 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:891: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ @root_validator(pre=True) -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:121 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:121: DeprecationWarning: pkg_resources is deprecated as an API - warnings.warn("pkg_resources is deprecated as an API", DeprecationWarning) +../proxy/_types.py:912 + /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:912: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ + @root_validator(pre=True) -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870: 10 warnings - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`. +../utils.py:39 + /Users/krrishdholakia/Documents/litellm/litellm/utils.py:39: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html + import pkg_resources # type: ignore + +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: 10 warnings + /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`. Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages declare_namespace(pkg) -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.cloud')`. +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 + /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.cloud')`. Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages declare_namespace(pkg) -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2349 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2349 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2349 -test_completion.py::test_gemini_completion_call_error - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2349: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`. +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2317 +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2317 +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2317 + /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2317: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`. Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages declare_namespace(parent) -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.logging')`. +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 + /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.logging')`. Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages declare_namespace(pkg) -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.iam')`. +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 + /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.iam')`. Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages declare_namespace(pkg) -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('mpl_toolkits')`. +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 + /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('mpl_toolkits')`. Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages declare_namespace(pkg) -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('sphinxcontrib')`. +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 + /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('sphinxcontrib')`. Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages declare_namespace(pkg) -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 -../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870 - /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('zope')`. +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 +../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832 + /opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2832: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('zope')`. Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages declare_namespace(pkg) -test_completion.py::test_gemini_completion_call_error - /opt/homebrew/lib/python3.11/site-packages/google/rpc/__init__.py:20: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.rpc')`. - Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages - pkg_resources.declare_namespace(__name__) +test_streaming.py::test_completion_predibase_streaming[False] + /opt/homebrew/lib/python3.11/site-packages/httpx/_content.py:204: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content. + warnings.warn(message, DeprecationWarning) -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html -======================== 1 passed, 63 warnings in 1.48s ======================== +=========================== short test summary info ============================ +FAILED test_streaming.py::test_completion_predibase_streaming[True] - Failed:... +=================== 1 failed, 1 passed, 64 warnings in 5.28s =================== diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 7f0977b15..f726ed95a 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -85,20 +85,33 @@ def test_completion_azure_command_r(): pytest.fail(f"Error occurred: {e}") -@pytest.mark.skip(reason="local test") -def test_completion_predibase(): +# @pytest.mark.skip(reason="local test") +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_completion_predibase(sync_mode): try: litellm.set_verbose = True - response = completion( - model="predibase/llama-3-8b-instruct", - tenant_id="c4768f95", - api_base="https://serving.app.predibase.com", - api_key=os.getenv("PREDIBASE_API_KEY"), - messages=[{"role": "user", "content": "What is the meaning of life?"}], - ) + if sync_mode: + response = completion( + model="predibase/llama-3-8b-instruct", + tenant_id="c4768f95", + api_base="https://serving.app.predibase.com", + api_key=os.getenv("PREDIBASE_API_KEY"), + messages=[{"role": "user", "content": "What is the meaning of life?"}], + ) - print(response) + print(response) + else: + response = await litellm.acompletion( + model="predibase/llama-3-8b-instruct", + tenant_id="c4768f95", + api_base="https://serving.app.predibase.com", + api_key=os.getenv("PREDIBASE_API_KEY"), + messages=[{"role": "user", "content": "What is the meaning of life?"}], + ) + + print(response) except litellm.Timeout as e: pass except Exception as e: diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 7d639d7a3..7e5a10265 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -5,6 +5,7 @@ import sys, os, asyncio import traceback import time, pytest from pydantic import BaseModel +from typing import Tuple sys.path.insert( 0, os.path.abspath("../..") @@ -142,7 +143,7 @@ def validate_last_format(chunk): ), "'finish_reason' should be a string." -def streaming_format_tests(idx, chunk): +def streaming_format_tests(idx, chunk) -> Tuple[str, bool]: extracted_chunk = "" finished = False print(f"chunk: {chunk}") @@ -306,6 +307,70 @@ def test_completion_azure_stream(): # test_completion_azure_stream() +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_completion_predibase_streaming(sync_mode): + try: + litellm.set_verbose = True + + if sync_mode: + response = completion( + model="predibase/llama-3-8b-instruct", + tenant_id="c4768f95", + api_base="https://serving.app.predibase.com", + api_key=os.getenv("PREDIBASE_API_KEY"), + messages=[{"role": "user", "content": "What is the meaning of life?"}], + stream=True, + ) + + complete_response = "" + for idx, init_chunk in enumerate(response): + chunk, finished = streaming_format_tests(idx, init_chunk) + complete_response += chunk + custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"] + print(f"custom_llm_provider: {custom_llm_provider}") + assert custom_llm_provider == "predibase" + if finished: + assert isinstance( + init_chunk.choices[0], litellm.utils.StreamingChoices + ) + break + if complete_response.strip() == "": + raise Exception("Empty response received") + else: + response = await litellm.acompletion( + model="predibase/llama-3-8b-instruct", + tenant_id="c4768f95", + api_base="https://serving.app.predibase.com", + api_key=os.getenv("PREDIBASE_API_KEY"), + messages=[{"role": "user", "content": "What is the meaning of life?"}], + stream=True, + ) + + # await response + + complete_response = "" + idx = 0 + async for init_chunk in response: + chunk, finished = streaming_format_tests(idx, init_chunk) + complete_response += chunk + custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"] + print(f"custom_llm_provider: {custom_llm_provider}") + assert custom_llm_provider == "predibase" + idx += 1 + if finished: + assert isinstance( + init_chunk.choices[0], litellm.utils.StreamingChoices + ) + break + if complete_response.strip() == "": + raise Exception("Empty response received") + + print(f"complete_response: {complete_response}") + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") def test_completion_azure_function_calling_stream(): diff --git a/litellm/utils.py b/litellm/utils.py index 7ccb5e8ff..fe52e6c63 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -9757,6 +9757,50 @@ class CustomStreamWrapper: "finish_reason": finish_reason, } + def handle_predibase_chunk(self, chunk): + try: + if type(chunk) != str: + chunk = chunk.decode( + "utf-8" + ) # DO NOT REMOVE this: This is required for HF inference API + Streaming + text = "" + is_finished = False + finish_reason = "" + print_verbose(f"chunk: {chunk}") + if chunk.startswith("data:"): + data_json = json.loads(chunk[5:]) + print_verbose(f"data json: {data_json}") + if "token" in data_json and "text" in data_json["token"]: + text = data_json["token"]["text"] + if data_json.get("details", False) and data_json["details"].get( + "finish_reason", False + ): + is_finished = True + finish_reason = data_json["details"]["finish_reason"] + elif data_json.get( + "generated_text", False + ): # if full generated text exists, then stream is complete + text = "" # don't return the final bos token + is_finished = True + finish_reason = "stop" + elif data_json.get("error", False): + raise Exception(data_json.get("error")) + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + elif "error" in chunk: + raise ValueError(chunk) + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + except Exception as e: + traceback.print_exc() + raise e + def handle_huggingface_chunk(self, chunk): try: if type(chunk) != str: @@ -10391,6 +10435,11 @@ class CustomStreamWrapper: completion_obj["content"] = response_obj["text"] if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "predibase": + response_obj = self.handle_predibase_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] elif ( self.custom_llm_provider and self.custom_llm_provider == "baseten" ): # baseten doesn't provide streaming @@ -11008,6 +11057,7 @@ class CustomStreamWrapper: or self.custom_llm_provider == "sagemaker" or self.custom_llm_provider == "gemini" or self.custom_llm_provider == "cached_response" + or self.custom_llm_provider == "predibase" or self.custom_llm_provider in litellm.openai_compatible_endpoints ): async for chunk in self.completion_stream: From 4a7be9163bb5dc4bbedf47c8b82b6bb18ee80bfb Mon Sep 17 00:00:00 2001 From: CyanideByte Date: Thu, 9 May 2024 17:42:19 -0700 Subject: [PATCH 28/34] Globally filtering pydantic conflict warnings --- litellm/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/litellm/__init__.py b/litellm/__init__.py index 4f72504f6..6f8d4e8df 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1,3 +1,6 @@ +### Hide pydantic namespace conflict warnings globally ### +import warnings +warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*") ### INIT VARIABLES ### import threading, requests, os from typing import Callable, List, Optional, Dict, Union, Any, Literal From 9083d8e490b485b9e72b75091b1a530833228ad0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 17:55:27 -0700 Subject: [PATCH 29/34] fix: fix linting errors --- .pre-commit-config.yaml | 16 ++++++++-------- litellm/llms/huggingface_restapi.py | 7 ++++--- litellm/llms/predibase.py | 8 ++++---- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e8bb1ff66..cc41d85f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,11 +16,11 @@ repos: name: Check if files match entry: python3 ci_cd/check_files_match.py language: system -# - repo: local -# hooks: -# - id: mypy -# name: mypy -# entry: python3 -m mypy --ignore-missing-imports -# language: system -# types: [python] -# files: ^litellm/ \ No newline at end of file +- repo: local + hooks: + - id: mypy + name: mypy + entry: python3 -m mypy --ignore-missing-imports + language: system + types: [python] + files: ^litellm/ \ No newline at end of file diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py index b250f3013..a2c4457c2 100644 --- a/litellm/llms/huggingface_restapi.py +++ b/litellm/llms/huggingface_restapi.py @@ -399,10 +399,11 @@ class Huggingface(BaseLLM): data = { "inputs": prompt, "parameters": optional_params, - "stream": ( + "stream": ( # type: ignore True if "stream" in optional_params - and optional_params["stream"] == True + and isinstance(optional_params["stream"], bool) + and optional_params["stream"] == True # type: ignore else False ), } @@ -433,7 +434,7 @@ class Huggingface(BaseLLM): data = { "inputs": prompt, "parameters": inference_params, - "stream": ( + "stream": ( # type: ignore True if "stream" in optional_params and optional_params["stream"] == True diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py index f3935984d..ef9c6b0ba 100644 --- a/litellm/llms/predibase.py +++ b/litellm/llms/predibase.py @@ -129,7 +129,7 @@ class PredibaseChatCompletion(BaseLLM): ) super().__init__() - def validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict: + def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict: if api_key is None: raise ValueError( "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params" @@ -309,7 +309,7 @@ class PredibaseChatCompletion(BaseLLM): logger_fn=None, headers: dict = {}, ) -> Union[ModelResponse, CustomStreamWrapper]: - headers = self.validate_environment(api_key, headers) + headers = self._validate_environment(api_key, headers) completion_url = "" input_text = "" base_url = "https://serving.app.predibase.com" @@ -411,13 +411,13 @@ class PredibaseChatCompletion(BaseLLM): data=json.dumps(data), stream=stream, ) - response = CustomStreamWrapper( + _response = CustomStreamWrapper( response.iter_lines(), model, custom_llm_provider="predibase", logging_obj=logging_obj, ) - return response + return _response ### SYNC COMPLETION else: response = requests.post( From 425efc60f48421f7cbad2ce40fb4e6fabfc90a1b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 18:12:28 -0700 Subject: [PATCH 30/34] fix(main.py): fix linting error --- litellm/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index cec49d35f..273de4d2e 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1810,7 +1810,7 @@ def completion( or get_secret("PREDIBASE_API_KEY") ) - model_response = predibase_chat_completions.completion( + _model_response = predibase_chat_completions.completion( model=model, messages=messages, model_response=model_response, @@ -1832,8 +1832,8 @@ def completion( and optional_params["stream"] == True and acompletion == False ): - return response - response = model_response + return _model_response + response = _model_response elif custom_llm_provider == "ai21": custom_llm_provider = "ai21" ai21_key = ( From 5a38438c3f32a6bdd9366fd0b42fe8673df26024 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 18:40:45 -0700 Subject: [PATCH 31/34] docs(customer_routing.md): add region-based routing for specific customers, to docs --- .../my-website/docs/proxy/customer_routing.md | 83 +++++++++++++++++++ docs/my-website/sidebars.js | 1 + 2 files changed, 84 insertions(+) create mode 100644 docs/my-website/docs/proxy/customer_routing.md diff --git a/docs/my-website/docs/proxy/customer_routing.md b/docs/my-website/docs/proxy/customer_routing.md new file mode 100644 index 000000000..4c8a60af8 --- /dev/null +++ b/docs/my-website/docs/proxy/customer_routing.md @@ -0,0 +1,83 @@ +# Region-based Routing + +Route specific customers to eu-only models. + +By specifying 'allowed_model_region' for a customer, LiteLLM will filter-out any models in a model group which is not in the allowed region (i.e. 'eu'). + +[**See Code**](https://github.com/BerriAI/litellm/blob/5eb12e30cc5faa73799ebc7e48fc86ebf449c879/litellm/router.py#L2938) + +### 1. Create customer with region-specification + +Use the litellm 'end-user' object for this. + +End-users can be tracked / id'ed by passing the 'user' param to litellm in an openai chat completion/embedding call. + +```bash +curl -X POST --location 'http://0.0.0.0:4000/end_user/new' \ +--header 'Authorization: Bearer sk-1234' \ +--header 'Content-Type: application/json' \ +--data '{ + "user_id" : "ishaan-jaff-45", + "allowed_model_region": "eu", # 👈 SPECIFY ALLOWED REGION='eu' +}' +``` + +### 2. Add eu models to model-group + +Add eu models to a model group. For azure models, litellm can automatically infer the region (no need to set it). + +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: azure/gpt-35-turbo-eu # 👈 EU azure model + api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ + api_key: os.environ/AZURE_EUROPE_API_KEY + - model_name: gpt-3.5-turbo + litellm_params: + model: azure/chatgpt-v-2 + api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ + api_version: "2023-05-15" + api_key: os.environ/AZURE_API_KEY + +router_settings: + enable_pre_call_checks: true # 👈 IMPORTANT +``` + +Start the proxy + +```yaml +litellm --config /path/to/config.yaml +``` + +### 3. Test it! + +Make a simple chat completions call to the proxy. In the response headers, you should see the returned api base. + +```bash +curl -X POST --location 'http://localhost:4000/chat/completions' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer sk-1234' \ +--data '{ + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "what is the meaning of the universe? 1234" + }], + "user": "ishaan-jaff-45" # 👈 USER ID +} +' +``` + +Expected API Base in response headers + +``` +x-litellm-api-base: "https://my-endpoint-europe-berri-992.openai.azure.com/" +``` + +### FAQ + +**What happens if there are no available models for that region?** + +Since the router filters out models not in the specified region, it will return back as an error to the user, if no models in that region are available. \ No newline at end of file diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index d00d853a0..3c968ea57 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -50,6 +50,7 @@ const sidebars = { items: ["proxy/logging", "proxy/streaming_logging"], }, "proxy/team_based_routing", + "proxy/customer_routing", "proxy/ui", "proxy/cost_tracking", "proxy/token_auth", From 491e17734856fbd28260bbd3222b3ca9c49a86e0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 18:44:19 -0700 Subject: [PATCH 32/34] fix(predibase.py): fix async completion call --- litellm/llms/predibase.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py index ef9c6b0ba..1f96e7c67 100644 --- a/litellm/llms/predibase.py +++ b/litellm/llms/predibase.py @@ -390,7 +390,7 @@ class PredibaseChatCompletion(BaseLLM): model=model, messages=messages, data=data, - api_base=api_base, + api_base=completion_url, model_response=model_response, print_verbose=print_verbose, encoding=encoding, From 76d4290591901961ac97cd53feee820b329b4dd5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 19:07:19 -0700 Subject: [PATCH 33/34] fix(predibase.py): fix event loop closed error --- litellm/llms/predibase.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py index 1f96e7c67..e377b02d4 100644 --- a/litellm/llms/predibase.py +++ b/litellm/llms/predibase.py @@ -124,9 +124,6 @@ class PredibaseConfig: class PredibaseChatCompletion(BaseLLM): def __init__(self) -> None: - self.async_handler = AsyncHTTPHandler( - timeout=httpx.Timeout(timeout=litellm.request_timeout, connect=5.0) - ) super().__init__() def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict: @@ -457,8 +454,10 @@ class PredibaseChatCompletion(BaseLLM): logger_fn=None, headers={}, ) -> ModelResponse: - - response = await self.async_handler.post( + async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) + ) + response = await async_handler.post( api_base, headers=headers, data=json.dumps(data) ) return self.process_response( @@ -491,9 +490,11 @@ class PredibaseChatCompletion(BaseLLM): logger_fn=None, headers={}, ) -> CustomStreamWrapper: - + async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) + ) data["stream"] = True - response = await self.async_handler.post( + response = await async_handler.post( url="https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate_stream", headers=headers, data=json.dumps(data), From 714370956fe74c503da238750d2a35df47beaeec Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 9 May 2024 22:18:16 -0700 Subject: [PATCH 34/34] fix(predibase.py): fix async streaming --- litellm/llms/predibase.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py index e377b02d4..c3424d244 100644 --- a/litellm/llms/predibase.py +++ b/litellm/llms/predibase.py @@ -454,10 +454,10 @@ class PredibaseChatCompletion(BaseLLM): logger_fn=None, headers={}, ) -> ModelResponse: - async_handler = AsyncHTTPHandler( + self.async_handler = AsyncHTTPHandler( timeout=httpx.Timeout(timeout=600.0, connect=5.0) ) - response = await async_handler.post( + response = await self.async_handler.post( api_base, headers=headers, data=json.dumps(data) ) return self.process_response( @@ -490,12 +490,12 @@ class PredibaseChatCompletion(BaseLLM): logger_fn=None, headers={}, ) -> CustomStreamWrapper: - async_handler = AsyncHTTPHandler( + self.async_handler = AsyncHTTPHandler( timeout=httpx.Timeout(timeout=600.0, connect=5.0) ) data["stream"] = True - response = await async_handler.post( - url="https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate_stream", + response = await self.async_handler.post( + url=api_base, headers=headers, data=json.dumps(data), stream=True,