Krrish Dholakia 2023-09-06 11:21:39 -07:00
parent af60b2ba77
commit 48ee4a08ac
7 changed files with 16 additions and 13 deletions

View file

@@ -14,7 +14,7 @@ class ReplicateError(Exception):
         )  # Call the base class constructor with the parameters it needs

 # Function to start a prediction and get the prediction URL
-def start_prediction(version_id, input_data, api_token):
+def start_prediction(version_id, input_data, api_token, logging_obj):
     base_url = "https://api.replicate.com/v1"
     headers = {
         "Authorization": f"Token {api_token}",
@@ -27,12 +27,19 @@ def start_prediction(version_id, input_data, api_token):
         "max_new_tokens": 500,
     }

+    ## LOGGING
+    logging_obj.pre_call(
+        input=input_data["prompt"],
+        api_key="",
+        additional_args={"complete_input_dict": initial_prediction_data, "headers": headers},
+    )
+
     response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
     if response.status_code == 201:
         response_data = response.json()
         return response_data.get("urls", {}).get("get")
     else:
-        raise ReplicateError(response.status_code, "Failed to start prediction.")
+        raise ReplicateError(response.status_code, message=response.text)

 # Function to handle prediction response (non-streaming)
 def handle_prediction_response(prediction_url, api_token, print_verbose):
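
Moving the pre_call hook into start_prediction means the prompt, the full request payload, and the headers are recorded before the HTTP request fires, so a failed prediction still leaves a trace. A minimal sketch of an object satisfying the call shape used here (SimpleLogger is a hypothetical stand-in, not litellm's real logging object):

    class SimpleLogger:
        # Hypothetical logging_obj: only pre_call(input, api_key, additional_args) is required here.
        def pre_call(self, input, api_key, additional_args=None):
            additional_args = additional_args or {}
            # Capture the prompt plus the request payload/headers before the call is made.
            print(f"PRE-CALL input={input!r}")
            print(f"PRE-CALL payload={additional_args.get('complete_input_dict')}")
            print(f"PRE-CALL headers={additional_args.get('headers')}")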
@@ -111,18 +118,12 @@ def completion(
         **optional_params
     }

-    ## LOGGING
-    logging_obj.pre_call(
-        input=prompt,
-        api_key="",
-        additional_args={"complete_input_dict": input_data},
-    )
-
     ## COMPLETION CALL
     ## Replicate Completion calls have 2 steps
     ## Step1: Start Prediction: gets a prediction url
     ## Step2: Poll prediction url for response
     ## Step2: is handled with and without streaming
-    prediction_url = start_prediction(version_id, input_data, api_key)
+    prediction_url = start_prediction(version_id, input_data, api_key, logging_obj=logging_obj)
     print_verbose(prediction_url)

     # Handle the prediction response (streaming or non-streaming)
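
The comments above describe Replicate's asynchronous API: POSTing a prediction returns immediately with a URL, and the client polls that URL until the job settles. A rough sketch of the polling half, assuming Replicate's documented prediction object with status, output, and error fields (poll_prediction and its interval are illustrative, not the repo's handle_prediction_response):

    import time
    import requests

    def poll_prediction(prediction_url, api_token, interval=0.5):
        # Poll the URL returned by start_prediction until the job succeeds or fails.
        headers = {"Authorization": f"Token {api_token}"}
        while True:
            data = requests.get(prediction_url, headers=headers).json()
            status = data.get("status")
            if status == "succeeded":
                return data.get("output")
            if status in ("failed", "canceled"):
                # The real handler raises ReplicateError; a plain exception keeps the sketch self-contained.
                raise Exception(f"prediction {status}: {data.get('error')}")
            time.sleep(interval)  # back off between polls instead of busy-waiting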

View file

@@ -18,13 +18,15 @@ model_alias_map = {
 litellm.model_alias_map = model_alias_map

-print(
+try:
     completion(
         "llama2",
         messages=[{"role": "user", "content": "Hey, how's it going?"}],
         top_p=0.1,
-        temperature=0,
+        temperature=0.1,
         num_beams=4,
         max_tokens=60,
     )
-)
+except Exception as e:
+    print(e.status_code)
+    print(e)
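
For reference, litellm.model_alias_map resolves a short alias to a full provider model string before routing, which is what lets this test call completion("llama2") directly. A sketch of the setup (the Replicate model string and version hash are placeholders, not the values from the test file):

    import litellm
    from litellm import completion

    # Placeholder alias target; a real entry would carry an actual version hash.
    litellm.model_alias_map = {
        "llama2": "replicate/llama-2-70b-chat:<version-hash>"
    }

    try:
        completion("llama2", messages=[{"role": "user", "content": "Hey, how's it going?"}])
    except Exception as e:
        # With the ReplicateError change above, the message now carries the provider's response body.
        print(e)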

View file

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.541"
+version = "0.1.542"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"