mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix: allow api base to be set for all providers
enables proxy use cases
This commit is contained in:
parent
72f55a4e6c
commit
00993f3575
7 changed files with 76 additions and 11 deletions
|
@ -74,8 +74,8 @@ class ReplicateConfig():
|
|||
|
||||
|
||||
# Function to start a prediction and get the prediction URL
|
||||
def start_prediction(version_id, input_data, api_token, logging_obj):
|
||||
base_url = "https://api.replicate.com/v1"
|
||||
def start_prediction(version_id, input_data, api_token, api_base, logging_obj):
|
||||
base_url = api_base
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json"
|
||||
|
@ -159,6 +159,7 @@ def model_to_version_id(model):
|
|||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
logging_obj,
|
||||
|
@ -208,7 +209,7 @@ def completion(
|
|||
## Step2: Poll prediction url for response
|
||||
## Step2: is handled with and without streaming
|
||||
model_response["created"] = time.time() # for pricing this must remain right before calling api
|
||||
prediction_url = start_prediction(version_id, input_data, api_key, logging_obj=logging_obj)
|
||||
prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj)
|
||||
print_verbose(prediction_url)
|
||||
|
||||
# Handle the prediction response (streaming or non-streaming)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue