Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
use replicate http requests instead

commit c45b132675
parent 3d6836417e

3 changed files with 188 additions and 80 deletions
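The substance of this commit is a new `litellm/llms/replicate.py` handler that calls Replicate over plain HTTP instead of importing the `replicate` Python SDK. That new file is not shown in the excerpt below, so the following is only a rough sketch of the idea, assuming Replicate's public predictions endpoint (`POST https://api.replicate.com/v1/predictions`, then polling the returned prediction URL); the function name and parameters here are illustrative, not the actual module contents.

```python
# Illustrative sketch only -- NOT the contents of litellm/llms/replicate.py.
# Assumes Replicate's public HTTP API: create a prediction, then poll it.
import os
import time
import requests

REPLICATE_API_BASE = "https://api.replicate.com/v1/predictions"

def replicate_http_completion(version, prompt, api_key=None):
    """Create a prediction over HTTP and poll until it finishes."""
    token = api_key or os.environ.get("REPLICATE_API_TOKEN")
    headers = {
        "Authorization": f"Token {token}",
        "Content-Type": "application/json",
    }
    # Start the prediction
    start = requests.post(
        REPLICATE_API_BASE,
        headers=headers,
        json={"version": version, "input": {"prompt": prompt}},
    )
    start.raise_for_status()
    prediction = start.json()

    # Poll the prediction URL until it reaches a terminal status
    status_url = prediction["urls"]["get"]
    while prediction["status"] not in ("succeeded", "failed", "canceled"):
        time.sleep(0.5)
        prediction = requests.get(status_url, headers=headers).json()

    if prediction["status"] != "succeeded":
        raise RuntimeError(f"Replicate prediction ended with status {prediction['status']}")
    # Replicate typically returns text output as a list of string chunks
    output = prediction.get("output") or []
    return "".join(output) if isinstance(output, list) else str(output)
```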
@@ -24,6 +24,7 @@ from .llms import ai21
from .llms import sagemaker
from .llms import bedrock
from .llms import huggingface_restapi
from .llms import replicate
from .llms import aleph_alpha
from .llms import baseten
import tiktoken
@@ -341,10 +342,7 @@ def completion(
        response = model_response
    elif "replicate" in model or custom_llm_provider == "replicate":
        # import replicate/if it fails then pip install replicate
        try:
            import replicate
        except:
            Exception("Replicate import failed please run `pip install replicate`")

        # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
        replicate_key = os.environ.get("REPLICATE_API_TOKEN")
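Both the old and the new code paths rely on the API key being discoverable from the environment: the branch falls back to `os.environ.get("REPLICATE_API_TOKEN")` when no key is supplied. A minimal sketch of that setup (the token value is a placeholder):

```python
import os

# litellm's replicate branch falls back to this environment variable
# when no explicit key is provided (see the hunk above).
os.environ["REPLICATE_API_TOKEN"] = "r8_xxxxxxxx"  # placeholder, not a real token
```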
@@ -358,56 +356,25 @@ def completion(
        )
        # set replicate key
        os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
        prompt = " ".join([message["content"] for message in messages])
        input = {
            "prompt": prompt
        }
        if "max_tokens" in optional_params:
            input["max_length"] = optional_params['max_tokens'] # for t5 models
            input["max_new_tokens"] = optional_params['max_tokens'] # for llama2 models
        ## LOGGING
        logging.pre_call(
            input=prompt,

        model_response = replicate.completion(
            model=model,
            messages=messages,
            model_response=model_response,
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            encoding=encoding, # for calculating input/output tokens
            api_key=replicate_key,
            additional_args={
                "complete_input_dict": input,
                "max_tokens": max_tokens,
            },
            logging_obj=logging,
        )
        ## COMPLETION CALL
        output = replicate.run(model, input=input)
        if "stream" in optional_params and optional_params["stream"] == True:
            # don't try to access stream object,
            # let the stream handler know this is replicate
            response = CustomStreamWrapper(output, "replicate", logging_obj=logging)
            response = CustomStreamWrapper(model_response, model, logging_obj=logging)
            return response
        response = ""
        for item in output:
            response += item
        completion_response = response
        ## LOGGING
        logging.post_call(
            input=prompt,
            api_key=replicate_key,
            original_response=completion_response,
            additional_args={
                "complete_input_dict": input,
                "max_tokens": max_tokens,
            },
        )
        ## USAGE
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(encoding.encode(completion_response))
        ## RESPONSE OBJECT
        model_response["choices"][0]["message"]["content"] = completion_response
        model_response["created"] = time.time()
        model_response["model"] = model
        model_response["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        response = model_response

    elif model in litellm.anthropic_models:
        anthropic_key = (
            api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY")
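From the caller's point of view nothing changes: `completion()` still routes any model name containing "replicate" (or `custom_llm_provider="replicate"`) through the branch shown above. A hedged usage sketch, assuming `REPLICATE_API_TOKEN` is set in the environment; the model slug below is illustrative, not taken from this commit.

```python
import os
import litellm

os.environ["REPLICATE_API_TOKEN"] = "r8_xxxxxxxx"  # placeholder, not a real token

messages = [{"role": "user", "content": "Say hello in one sentence."}]

# Non-streaming call: handled by the replicate branch shown in the diff above
response = litellm.completion(
    model="replicate/llama-2-70b-chat",  # illustrative model slug
    messages=messages,
    max_tokens=64,
)
print(response["choices"][0]["message"]["content"])

# Streaming call: the branch wraps the response in CustomStreamWrapper,
# so the result can be iterated chunk by chunk
for chunk in litellm.completion(
    model="replicate/llama-2-70b-chat",
    messages=messages,
    stream=True,
):
    print(chunk)
```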