forked from phoenix/litellm-mirror
add replicate support for max_tokens
This commit is contained in:
parent ef43141554
commit d4c4a138ca

4 changed files with 11 additions and 4 deletions
@@ -108,10 +108,9 @@ def completion(
     version_id = model_to_version_id(model)
     input_data = {
         "prompt": prompt,
-        "max_new_tokens": 50,
+        **optional_params
     }
-
 
     ## LOGGING
     logging_obj.pre_call(
         input=prompt,
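For context, a minimal sketch of what this first change accomplishes: the hard-coded "max_new_tokens": 50 is dropped and any caller-supplied parameters are spliced into the Replicate payload instead. The values below are made-up placeholders; only input_data, prompt, and optional_params appear in the diff itself, and optional_params is assumed to be built upstream by get_optional_params.

    # Sketch only: placeholder values, not the real call site.
    prompt = "Hello, how are you?"
    optional_params = {"max_new_tokens": 250, "temperature": 0.5}  # assumed to come from get_optional_params

    input_data = {
        "prompt": prompt,
        **optional_params,  # caller-supplied values now flow through instead of a fixed 50
    }
    # input_data -> {"prompt": "Hello, how are you?", "max_new_tokens": 250, "temperature": 0.5}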
@@ -142,6 +141,9 @@ def completion(
     )
 
     print_verbose(f"raw model_response: {result}")
+
+    if len(result) == 0: # edge case, where result from replicate is empty
+        result = " "
 
     ## Building RESPONSE OBJECT
     model_response["choices"][0]["message"]["content"] = result
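The second change guards against Replicate occasionally returning an empty string. A small illustration of the behaviour, with model_response reduced to a bare dict stand-in rather than the real response object:

    # Illustration only: model_response is a simplified stand-in dict here.
    result = ""  # what Replicate sometimes returns
    if len(result) == 0:  # edge case, where result from replicate is empty
        result = " "

    model_response = {"choices": [{"message": {"content": None}}]}
    model_response["choices"][0]["message"]["content"] = result
    # content is now " " rather than "", so downstream consumers never see an empty message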
@@ -352,10 +352,13 @@ def test_completion_azure_deployment_id():
 # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
 
 def test_completion_replicate_llama_2():
+    litellm.set_verbose = True
     model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
     try:
         response = completion(
-            model=model_name, messages=messages, custom_llm_provider="replicate"
+            model=model_name,
+            messages=messages,
+            custom_llm_provider="replicate"
         )
         print(response)
         # Add any assertions here to check the response
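The updated test exercises the same call path end to end. A hedged usage sketch, assuming a valid REPLICATE_API_TOKEN and a messages list shaped like the one used elsewhere in the test module (its contents are not shown in this diff); max_tokens is passed to exercise the mapping added in the next hunk:

    import os
    import litellm
    from litellm import completion

    os.environ["REPLICATE_API_TOKEN"] = "r8_..."  # placeholder token, not a real credential
    litellm.set_verbose = True

    messages = [{"role": "user", "content": "Hey, how's it going?"}]  # assumed shape
    response = completion(
        model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
        messages=messages,
        custom_llm_provider="replicate",
        max_tokens=80,  # translated to max_new_tokens for Replicate (see the next hunk)
    )
    print(response)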
@@ -707,6 +707,8 @@ def get_optional_params( # use the openai defaults
         if stream:
             optional_params["stream"] = stream
             return optional_params
+        if max_tokens != float("inf"):
+            optional_params["max_new_tokens"] = max_tokens
     elif custom_llm_provider == "together_ai" or ("togethercomputer" in model):
         if stream:
             optional_params["stream_tokens"] = stream
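This get_optional_params change is what makes max_tokens work for Replicate: the OpenAI-style max_tokens argument is renamed to Replicate's max_new_tokens. A standalone sketch of just that mapping; map_replicate_params is a hypothetical helper written for illustration, not part of litellm, and the real get_optional_params handles many more providers and parameters.

    # Hypothetical helper that isolates the mapping added above.
    def map_replicate_params(max_tokens=float("inf"), stream=False):
        optional_params = {}
        if stream:
            optional_params["stream"] = stream
            return optional_params
        if max_tokens != float("inf"):
            optional_params["max_new_tokens"] = max_tokens  # Replicate's name for the output-token cap
        return optional_params

    print(map_replicate_params(max_tokens=80))   # {'max_new_tokens': 80}
    print(map_replicate_params(stream=True))     # {'stream': True}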
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.539"
+version = "0.1.540"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
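Finally, the version bump publishes the change; once this release is available on PyPI, it can be picked up with pip install litellm==0.1.540.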