add replicate support for max_tokens

ishaan-jaff 2023-09-06 10:38:19 -07:00
parent ef43141554
commit d4c4a138ca
4 changed files with 11 additions and 4 deletions

View file

@@ -108,10 +108,9 @@ def completion(
    version_id = model_to_version_id(model)
    input_data = {
        "prompt": prompt,
        "max_new_tokens": 50,
        **optional_params
    }
    ## LOGGING
    logging_obj.pre_call(
        input=prompt,
@@ -142,6 +141,9 @@ def completion(
    )
    print_verbose(f"raw model_response: {result}")
    if len(result) == 0: # edge case, where result from replicate is empty
        result = " "
    ## Building RESPONSE OBJECT
    model_response["choices"][0]["message"]["content"] = result

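Net effect of these two hunks: the hunk counts (10 lines before, 9 after) indicate the hardcoded "max_new_tokens": 50 is the line being dropped, so the token cap now reaches Replicate only through optional_params, and an empty result from Replicate is coerced to a single space before the response object is built. A minimal sketch of the payload assembly after this change (the prompt text and the value 256 are illustrative, not from the diff):

prompt = "Hello, how are you?"             # illustrative
optional_params = {"max_new_tokens": 256}  # built by get_optional_params (see the hunk below)

input_data = {
    "prompt": prompt,
    **optional_params,  # carries max_new_tokens when the caller set max_tokens
}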
View file

@@ -352,10 +352,13 @@ def test_completion_azure_deployment_id():
# Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
def test_completion_replicate_llama_2():
    litellm.set_verbose = True
    model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
    try:
        response = completion(
            model=model_name, messages=messages, custom_llm_provider="replicate"
            model=model_name,
            messages=messages,
            custom_llm_provider="replicate"
        )
        print(response)
        # Add any assertions here to check the response

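For context, a hedged sketch of what this commit enables from the caller's side once max_tokens is wired through (the message text and the value 80 are illustrative; the model ID is the one used in the test above):

from litellm import completion

messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = completion(
    model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
    messages=messages,
    custom_llm_provider="replicate",
    max_tokens=80,  # mapped to Replicate's max_new_tokens by get_optional_params
)
print(response)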
View file

@@ -707,6 +707,8 @@ def get_optional_params( # use the openai defaults
        if stream:
            optional_params["stream"] = stream
            return optional_params
        if max_tokens != float("inf"):
            optional_params["max_new_tokens"] = max_tokens
    elif custom_llm_provider == "together_ai" or ("togethercomputer" in model):
        if stream:
            optional_params["stream_tokens"] = stream

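The new replicate branch follows the pattern of the neighbouring providers: a value is only forwarded when the caller actually set it (float("inf") serves as the "not provided" sentinel for max_tokens), and the OpenAI-style name is translated to Replicate's max_new_tokens. A standalone sketch of just that mapping, under those assumptions (the helper name is hypothetical, not part of litellm):

def _replicate_optional_params(stream=False, max_tokens=float("inf")):
    # Hypothetical helper mirroring the replicate branch of get_optional_params.
    optional_params = {}
    if stream:
        optional_params["stream"] = stream
        return optional_params  # streaming returns early, as in the diff above
    if max_tokens != float("inf"):
        optional_params["max_new_tokens"] = max_tokens  # Replicate's name for the token cap
    return optional_params

print(_replicate_optional_params(max_tokens=80))  # {'max_new_tokens': 80}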
View file

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.539"
version = "0.1.540"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"