Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
improve replicate usage
commit dc0c084813 (parent 1bf4bfa85f)
2 changed files with 30 additions and 20 deletions
@@ -116,7 +116,7 @@ def completion(
             model
         ]  # update the model to the actual value if an alias has been passed in
         model_response = ModelResponse()
-        if deployment_id != None:
+        if deployment_id != None: # azure llms
             model=deployment_id
             custom_llm_provider="azure"
         elif (
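For context, the azure branch above lets a caller hand completion() an Azure deployment ID directly: when deployment_id is set, it becomes the model and the provider is forced to "azure". A minimal usage sketch, with a hypothetical deployment name that is not part of this commit:

    import litellm

    # "my-azure-gpt35" is a hypothetical Azure deployment name; per the branch above,
    # passing deployment_id makes completion() use it as the model with provider "azure".
    response = litellm.completion(
        model="gpt-3.5-turbo",
        deployment_id="my-azure-gpt35",
        messages=[{"role": "user", "content": "Hello"}],
    )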
@@ -124,10 +124,7 @@ def completion(
         ): # allow custom provider to be passed in via the model name "azure/chatgpt-test"
             custom_llm_provider = model.split("/", 1)[0]
             model = model.split("/", 1)[1]
-            if (
-                "replicate" == custom_llm_provider and "/" not in model
-            ): # handle the "replicate/llama2..." edge-case
-                model = custom_llm_provider + "/" + model
+
         # check if user passed in any of the OpenAI optional params
         optional_params = get_optional_params(
             functions=functions,
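The prefix handling kept in this hunk boils down to splitting the model name on its first "/". A simplified standalone sketch of what the two kept lines do (not the library's actual helper):

    def split_provider_prefix(model: str):
        # "azure/chatgpt-test" -> ("azure", "chatgpt-test"); a name without "/" is left alone.
        if "/" in model:
            provider, rest = model.split("/", 1)
            return provider, rest
        return None, model

The deleted block re-prefixed bare Replicate model names with "replicate/"; that workaround is no longer needed because the Replicate branch below now also matches on custom_llm_provider and litellm.replicate_models.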
@@ -340,22 +337,19 @@ def completion(
         model_response["model"] = model
         model_response["usage"] = response["usage"]
         response = model_response
-    elif "replicate" in model or custom_llm_provider == "replicate":
-        # import replicate/if it fails then pip install replicate
-
-
+    elif (
+        "replicate" in model or
+        custom_llm_provider == "replicate" or
+        model in litellm.replicate_models
+    ):
         # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
-        replicate_key = os.environ.get("REPLICATE_API_TOKEN")
-        if replicate_key == None:
-            # user did not set REPLICATE_API_TOKEN in .env
-            replicate_key = (
-                get_secret("REPLICATE_API_KEY")
-                or get_secret("REPLICATE_API_TOKEN")
-                or api_key
-                or litellm.replicate_key
-            )
-            # set replicate key
-            os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
+        replicate_key = None
+        replicate_key = (
+            get_secret("REPLICATE_API_KEY")
+            or get_secret("REPLICATE_API_TOKEN")
+            or api_key
+            or litellm.replicate_key
+        )

         model_response = replicate.completion(
             model=model,
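After this change the key lookup is a plain precedence chain, and the resolved key stays in a local variable for the replicate.completion(...) call that follows instead of being exported via os.environ["REPLICATE_API_TOKEN"]. A rough equivalent of the new resolution order, using os.environ.get as a stand-in for litellm's get_secret (which may also consult secret managers):

    import os

    def resolve_replicate_key(api_key=None, module_level_key=None):
        # Precedence per the new code: REPLICATE_API_KEY, then REPLICATE_API_TOKEN,
        # then an api_key passed by the caller, then litellm.replicate_key.
        return (
            os.environ.get("REPLICATE_API_KEY")
            or os.environ.get("REPLICATE_API_TOKEN")
            or api_key
            or module_level_key
        )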