diff --git a/litellm/llms/ai21.py b/litellm/llms/ai21.py index 4e7d6b1fe..651ab92ec 100644 --- a/litellm/llms/ai21.py +++ b/litellm/llms/ai21.py @@ -92,6 +92,7 @@ def validate_environment(api_key): def completion( model: str, messages: list, + api_base: str, model_response: ModelResponse, print_verbose: Callable, encoding, @@ -137,7 +138,7 @@ def completion( ) ## COMPLETION CALL response = requests.post( - "https://api.ai21.com/studio/v1/" + model + "/complete", headers=headers, data=json.dumps(data) + api_base + model + "/complete", headers=headers, data=json.dumps(data) ) if "stream" in optional_params and optional_params["stream"] == True: return response.iter_lines() diff --git a/litellm/llms/aleph_alpha.py b/litellm/llms/aleph_alpha.py index f9ff5536e..0e83b76a7 100644 --- a/litellm/llms/aleph_alpha.py +++ b/litellm/llms/aleph_alpha.py @@ -160,6 +160,7 @@ def validate_environment(api_key): def completion( model: str, messages: list, + api_base: str, model_response: ModelResponse, print_verbose: Callable, encoding, @@ -178,7 +179,7 @@ def completion( if k not in optional_params: # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v - completion_url = "https://api.aleph-alpha.com/complete" + completion_url = api_base model = model prompt = "" if "control" in model: # follow the ###Instruction / ###Response format diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py index 001c6ba26..cd6032c56 100644 --- a/litellm/llms/cohere.py +++ b/litellm/llms/cohere.py @@ -100,6 +100,7 @@ def validate_environment(api_key): def completion( model: str, messages: list, + api_base: str, model_response: ModelResponse, print_verbose: Callable, encoding, @@ -110,7 +111,7 @@ def completion( logger_fn=None, ): headers = validate_environment(api_key) - completion_url = "https://api.cohere.ai/v1/generate" + completion_url = api_base model = model prompt = " ".join(message["content"] for message in messages) diff --git a/litellm/llms/nlp_cloud.py b/litellm/llms/nlp_cloud.py index fe3cd26fc..b12c23ff5 100644 --- a/litellm/llms/nlp_cloud.py +++ b/litellm/llms/nlp_cloud.py @@ -96,6 +96,7 @@ def validate_environment(api_key): def completion( model: str, messages: list, + api_base: str, model_response: ModelResponse, print_verbose: Callable, encoding, @@ -114,7 +115,7 @@ def completion( if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v - completion_url_fragment_1 = "https://api.nlpcloud.io/v1/gpu/" + completion_url_fragment_1 = api_base completion_url_fragment_2 = "/generation" model = model text = " ".join(message["content"] for message in messages) @@ -125,6 +126,7 @@ def completion( } completion_url = completion_url_fragment_1 + model + completion_url_fragment_2 + ## LOGGING logging_obj.pre_call( input=text, diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py index 7aebdf5fa..0912af5c0 100644 --- a/litellm/llms/replicate.py +++ b/litellm/llms/replicate.py @@ -74,8 +74,8 @@ class ReplicateConfig(): # Function to start a prediction and get the prediction URL -def start_prediction(version_id, input_data, api_token, logging_obj): - base_url = "https://api.replicate.com/v1" +def start_prediction(version_id, input_data, api_token, api_base, logging_obj): + base_url = api_base headers = { "Authorization": f"Token {api_token}", "Content-Type": "application/json" @@ -159,6 +159,7 @@ def model_to_version_id(model): def completion( model: str, messages: list, + api_base: str, model_response: ModelResponse, print_verbose: Callable, logging_obj, @@ -208,7 +209,7 @@ def completion( ## Step2: Poll prediction url for response ## Step2: is handled with and without streaming model_response["created"] = time.time() # for pricing this must remain right before calling api - prediction_url = start_prediction(version_id, input_data, api_key, logging_obj=logging_obj) + prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj) print_verbose(prediction_url) # Handle the prediction response (streaming or non-streaming) diff --git a/litellm/llms/together_ai.py b/litellm/llms/together_ai.py index 2a03952e0..9fc48b4f6 100644 --- a/litellm/llms/together_ai.py +++ b/litellm/llms/together_ai.py @@ -80,6 +80,7 @@ def validate_environment(api_key): def completion( model: str, messages: list, + api_base: str, model_response: ModelResponse, print_verbose: Callable, encoding, @@ -129,7 +130,7 @@ def completion( and optional_params["stream_tokens"] == True ): response = requests.post( - "https://api.together.xyz/inference", + api_base, headers=headers, data=json.dumps(data), stream=optional_params["stream_tokens"], @@ -137,7 +138,7 @@ def completion( return response.iter_lines() else: response = requests.post( - "https://api.together.xyz/inference", + api_base, headers=headers, data=json.dumps(data) ) diff --git a/litellm/main.py b/litellm/main.py index b05caa453..37927a337 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -560,9 +560,17 @@ def completion( or get_secret("REPLICATE_API_TOKEN") ) + api_base = ( + api_base + or litellm.api_base + or get_secret("REPLICATE_API_BASE") + or "https://api.replicate.com/v1" + ) + model_response = replicate.completion( model=model, messages=messages, + api_base=api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, @@ -611,9 +619,17 @@ def completion( api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key ) + api_base = ( + api_base + or litellm.api_base + or get_secret("NLP_CLOUD_API_BASE") + or "https://api.nlpcloud.io/v1/gpu/" + ) + model_response = nlp_cloud.completion( model=model, messages=messages, + api_base=api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, @@ -634,9 +650,17 @@ def completion( api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key ) + api_base = ( + api_base + or litellm.api_base + or get_secret("ALEPH_ALPHA_API_BASE") + or "https://api.aleph-alpha.com/complete" + ) + model_response = aleph_alpha.completion( model=model, messages=messages, + api_base=api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, @@ -661,9 +685,18 @@ def completion( or get_secret("CO_API_KEY") or litellm.api_key ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("COHERE_API_BASE") + or "https://api.cohere.ai/v1/generate" + ) + model_response = cohere.completion( model=model, messages=messages, + api_base=api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, @@ -687,6 +720,14 @@ def completion( litellm.openai_key or get_secret("DEEPINFRA_API_KEY") ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("DEEPINFRA_API_BASE") + or "https://api.deepinfra.com/v1/openai" + ) + ## LOGGING logging.pre_call( input=messages, @@ -698,7 +739,7 @@ def completion( response = openai.ChatCompletion.create( model=model, messages=messages, - api_base="https://api.deepinfra.com/v1/openai", # use the deepinfra api base + api_base=api_base, # use the deepinfra api base api_type="openai", api_version=api_version, # default None **optional_params, @@ -840,10 +881,18 @@ def completion( or get_secret("TOGETHERAI_API_KEY") or litellm.api_key ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("TOGETHERAI_API_BASE") + or "https://api.together.xyz/inference" + ) model_response = together_ai.completion( model=model, messages=messages, + api_base=api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, @@ -923,10 +972,19 @@ def completion( or litellm.ai21_key or os.environ.get("AI21_API_KEY") or litellm.api_key - ) + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("AI21_API_BASE") + or "https://api.ai21.com/studio/v1/" + ) + model_response = ai21.completion( model=model, messages=messages, + api_base=api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params,