fix: allow api base to be set for all providers

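Enables proxy use cases: callers can now pass an `api_base` so requests go to a custom endpoint instead of the provider's hard-coded URL.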
This commit is contained in:
Krrish Dholakia 2023-10-19 19:07:42 -07:00
parent 72f55a4e6c
commit 00993f3575
7 changed files with 76 additions and 11 deletions

View file

@ -74,8 +74,8 @@ class ReplicateConfig():
# Function to start a prediction and get the prediction URL
def start_prediction(version_id, input_data, api_token, logging_obj):
base_url = "https://api.replicate.com/v1"
def start_prediction(version_id, input_data, api_token, api_base, logging_obj):
base_url = api_base
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json"
@ -159,6 +159,7 @@ def model_to_version_id(model):
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj,
@ -208,7 +209,7 @@ def completion(
## Step2: Poll prediction url for response
## Step2: is handled with and without streaming
model_response["created"] = time.time() # for pricing this must remain right before calling api
prediction_url = start_prediction(version_id, input_data, api_key, logging_obj=logging_obj)
prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj)
print_verbose(prediction_url)
# Handle the prediction response (streaming or non-streaming)