Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

fix: bug fix when n>1 passed in

commit 253e8d27db, parent 2004b449e8
8 changed files with 119 additions and 43 deletions
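
To see what the fix is about, here is a minimal usage sketch of the call pattern the commit targets, assuming an OpenAI-style response object as returned by `litellm.completion` (the model name is only an example):

```python
import litellm

# Ask the provider for two candidate completions in a single call.
# After this fix, providers with native `n` support return one choice
# per candidate instead of only the first one.
response = litellm.completion(
    model="command-nightly",  # example; any provider with native n support
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    n=2,
)

# Every candidate is exposed as an OpenAI-format choice.
for i, choice in enumerate(response.choices):
    print(i, choice.message.content)
```
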
@@ -32,7 +32,7 @@ This list is constantly being updated.
 |---|---|---|---|---|---|---|---|---|---|---|
 |Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
 |OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
+|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
 |Cohere| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
 |Huggingface| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
 |Openrouter| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |

@@ -43,7 +43,7 @@ This list is constantly being updated.
 |TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
 |AlephAlpha| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
 |Palm| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
-|NLP Cloud| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
+|NLP Cloud| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
 |Petals| ✅ | ✅ | | ✅ | | | | | | |
 |Ollama| ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | | |

Binary file not shown.

@@ -1,10 +1,10 @@
-import os, types
+import os, types, traceback
 import json
 from enum import Enum
 import requests
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Choices, Message
 import litellm

 class AI21Error(Exception):

@@ -159,10 +159,14 @@ def completion(
             )
         else:
             try:
-                model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["data"]["text"]
-                model_response.choices[0].finish_reason = completion_response["completions"][0]["finishReason"]["reason"]
+                choices_list = []
+                for idx, item in enumerate(completion_response["completions"]):
+                    message_obj = Message(content=item["data"]["text"])
+                    choice_obj = Choices(finish_reason=item["finishReason"]["reason"], index=idx+1, message=message_obj)
+                    choices_list.append(choice_obj)
+                model_response["choices"] = choices_list
             except Exception as e:
-                raise AI21Error(message=json.dumps(completion_response), status_code=response.status_code)
+                raise AI21Error(message=traceback.format_exc(), status_code=response.status_code)

         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
         prompt_tokens = len(

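The hunk above is AI21's completion handler; the Cohere hunk below gets the same treatment. The shared pattern is to map every provider generation onto an OpenAI-format `Choices` object instead of writing only the first one into `choices[0]`. A standalone sketch of that mapping, using the field names from the AI21 payload shown above (the helper name is hypothetical):

```python
from litellm.utils import ModelResponse, Choices, Message

def map_completions_to_choices(completion_response: dict, model_response: ModelResponse) -> ModelResponse:
    # One OpenAI-format choice per provider completion, mirroring the loop in the diff.
    choices_list = []
    for idx, item in enumerate(completion_response["completions"]):
        message_obj = Message(content=item["data"]["text"])
        choice_obj = Choices(
            finish_reason=item["finishReason"]["reason"],
            index=idx + 1,  # the diff numbers choices starting at 1
            message=message_obj,
        )
        choices_list.append(choice_obj)
    model_response["choices"] = choices_list
    return model_response
```
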
@@ -2,9 +2,9 @@ import os, types
 import json
 from enum import Enum
 import requests
-import time
+import time, traceback
 from typing import Callable, Optional
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Choices, Message
 import litellm

 class CohereError(Exception):

@@ -156,11 +156,16 @@ def completion(
             )
         else:
             try:
-                model_response["choices"][0]["message"]["content"] = completion_response["generations"][0]["text"]
-            except:
-                raise CohereError(message=json.dumps(completion_response), status_code=response.status_code)
+                choices_list = []
+                for idx, item in enumerate(completion_response["generations"]):
+                    message_obj = Message(content=item["text"])
+                    choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
+                    choices_list.append(choice_obj)
+                model_response["choices"] = choices_list
+            except Exception as e:
+                raise CohereError(message=traceback.format_exc(), status_code=response.status_code)

-        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        ## CALCULATING USAGE
         prompt_tokens = len(
             encoding.encode(prompt)
         )

@@ -1,9 +1,9 @@
-import os, types
+import os, types, traceback
 import json
 from enum import Enum
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, get_secret
+from litellm.utils import ModelResponse, get_secret, Choices, Message
 import litellm
 import sys

@@ -33,7 +33,7 @@ class PalmConfig():

     - `top_p` (float): The API uses combined nucleus and top-k sampling. `top_p` configures the nucleus sampling. It sets the maximum cumulative probability of tokens to sample from.

-    - `maxOutputTokens` (int): Sets the maximum number of tokens to be returned in the output
+    - `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output
     """
     context: Optional[str]=None
     examples: Optional[list]=None

@@ -41,7 +41,7 @@ class PalmConfig():
     candidate_count: Optional[int]=None
     top_k: Optional[int]=None
     top_p: Optional[float]=None
-    maxOutputTokens: Optional[int]=None
+    max_output_tokens: Optional[int]=None

     def __init__(self,
                  context: Optional[str]=None,

@@ -50,7 +50,7 @@ class PalmConfig():
                  candidate_count: Optional[int]=None,
                  top_k: Optional[int]=None,
                  top_p: Optional[float]=None,
-                 maxOutputTokens: Optional[int]=None) -> None:
+                 max_output_tokens: Optional[int]=None) -> None:

         locals_ = locals()
         for key, value in locals_.items():

@@ -110,10 +110,16 @@ def completion(
     logging_obj.pre_call(
             input=prompt,
             api_key="",
-            additional_args={"complete_input_dict": {}},
+            additional_args={"complete_input_dict": {"optional_params": optional_params}},
         )
     ## COMPLETION CALL
-    response = palm.chat(messages=prompt)
+    try:
+        response = palm.generate_text(prompt=prompt, **optional_params)
+    except Exception as e:
+        raise PalmError(
+            message=str(e),
+            status_code=500,
+        )

     ## LOGGING
     logging_obj.post_call(

@@ -124,18 +130,17 @@ def completion(
     )
     print_verbose(f"raw model_response: {response}")
     ## RESPONSE OBJECT
-    completion_response = response.last
-
-    if "error" in completion_response:
-        raise PalmError(
-            message=completion_response["error"],
-            status_code=response.status_code,
-        )
-    else:
-        try:
-            model_response["choices"][0]["message"]["content"] = completion_response
-        except:
-            raise PalmError(message=json.dumps(completion_response), status_code=response.status_code)
+    completion_response = response
+    try:
+        choices_list = []
+        for idx, item in enumerate(completion_response.candidates):
+            message_obj = Message(content=item["output"])
+            choice_obj = Choices(index=idx+1, message=message_obj)
+            choices_list.append(choice_obj)
+        model_response["choices"] = choices_list
+    except Exception as e:
+        traceback.print_exc()
+        raise PalmError(message=traceback.format_exc(), status_code=response.status_code)

     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
     prompt_tokens = len(

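The section above is the PaLM handler. Two things change: the config key `maxOutputTokens` becomes the snake_case `max_output_tokens`, and the provider call moves from `palm.chat` (which only surfaces a single reply via `response.last`) to `palm.generate_text`, whose `candidates` list maps naturally onto multiple choices. A short usage sketch at the litellm level, mirroring the updated test further down (it assumes a PaLM API key is already configured in the environment):

```python
import litellm

# Provider-level default, now using the snake_case key.
litellm.PalmConfig(max_output_tokens=10, temperature=0.9)

response = litellm.completion(
    model="palm/chat-bison",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    n=2,
)
print(len(response.choices))  # expected to be > 1 after this fix
```
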
@@ -161,7 +161,7 @@ def completion(
             raise TogetherAIError(
                 message=json.dumps(completion_response["output"]), status_code=response.status_code
             )

         completion_text = completion_response["output"]["choices"][0]["text"]

         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.

@@ -50,7 +50,7 @@ def claude_test_completion():
     try:
         # OVERRIDE WITH DYNAMIC MAX TOKENS
         response_1 = litellm.completion(
-            model="together_ai/togethercomputer/llama-2-70b-chat",
+            model="claude-instant-1",
             messages=[{ "content": "Hello, how are you?","role": "user"}],
             max_tokens=10
         )

@@ -60,7 +60,7 @@ def claude_test_completion():

         # USE CONFIG TOKENS
         response_2 = litellm.completion(
-            model="together_ai/togethercomputer/llama-2-70b-chat",
+            model="claude-instant-1",
             messages=[{ "content": "Hello, how are you?","role": "user"}],
         )
         # Add any assertions here to check the response

@@ -68,6 +68,14 @@ def claude_test_completion():
         response_2_text = response_2.choices[0].message.content

         assert len(response_2_text) > len(response_1_text)

+        try:
+            response_3 = litellm.completion(model="claude-instant-1",
+                                            messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                            n=2)
+        except Exception as e:
+            print(e)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -99,6 +107,12 @@ def replicate_test_completion():
         response_2_text = response_2.choices[0].message.content

         assert len(response_2_text) > len(response_1_text)
+        try:
+            response_3 = litellm.completion(model="meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3",
+                                            messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                            n=2)
+        except:
+            pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -107,8 +121,8 @@ def replicate_test_completion():
 # Cohere

 def cohere_test_completion():
-    litellm.CohereConfig(max_tokens=200)
-    # litellm.set_verbose=True
+    # litellm.CohereConfig(max_tokens=200)
+    litellm.set_verbose=True
     try:
         # OVERRIDE WITH DYNAMIC MAX TOKENS
         response_1 = litellm.completion(

@@ -126,6 +140,11 @@ def cohere_test_completion():
         response_2_text = response_2.choices[0].message.content

         assert len(response_2_text) > len(response_1_text)

+        response_3 = litellm.completion(model="command-nightly",
+                                        messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                        n=2)
+        assert len(response_3.choices) > 1
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -135,7 +154,7 @@ def cohere_test_completion():

 def ai21_test_completion():
     litellm.AI21Config(maxTokens=10)
-    # litellm.set_verbose=True
+    litellm.set_verbose=True
     try:
         # OVERRIDE WITH DYNAMIC MAX TOKENS
         response_1 = litellm.completion(

@@ -155,6 +174,11 @@ def ai21_test_completion():
         print(f"response_2_text: {response_2_text}")

         assert len(response_2_text) < len(response_1_text)

+        response_3 = litellm.completion(model="j2-light",
+                                        messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                        n=2)
+        assert len(response_3.choices) > 1
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -164,7 +188,7 @@ def ai21_test_completion():

 def togetherai_test_completion():
     litellm.TogetherAIConfig(max_tokens=10)
-    # litellm.set_verbose=True
+    litellm.set_verbose=True
     try:
         # OVERRIDE WITH DYNAMIC MAX TOKENS
         response_1 = litellm.completion(

@@ -184,6 +208,14 @@ def togetherai_test_completion():
         print(f"response_2_text: {response_2_text}")

         assert len(response_2_text) < len(response_1_text)

+        try:
+            response_3 = litellm.completion(model="together_ai/togethercomputer/llama-2-70b-chat",
+                                            messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                            n=2)
+            pytest.fail(f"Error not raised when n=2 passed to provider")
+        except:
+            pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -192,7 +224,7 @@ def togetherai_test_completion():
 # Palm

 def palm_test_completion():
-    litellm.PalmConfig(maxOutputTokens=10)
+    litellm.PalmConfig(max_output_tokens=10, temperature=0.9)
     # litellm.set_verbose=True
     try:
         # OVERRIDE WITH DYNAMIC MAX TOKENS

@@ -213,6 +245,11 @@ def palm_test_completion():
         print(f"response_2_text: {response_2_text}")

         assert len(response_2_text) < len(response_1_text)

+        response_3 = litellm.completion(model="palm/chat-bison",
+                                        messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                        n=2)
+        assert len(response_3.choices) > 1
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -242,6 +279,14 @@ def nlp_cloud_test_completion():
         print(f"response_2_text: {response_2_text}")

         assert len(response_2_text) < len(response_1_text)

+        try:
+            response_3 = litellm.completion(model="dolphin",
+                                            messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                            n=2)
+            pytest.fail(f"Error not raised when n=2 passed to provider")
+        except:
+            pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -271,6 +316,14 @@ def aleph_alpha_test_completion():
         print(f"response_2_text: {response_2_text}")

         assert len(response_2_text) < len(response_1_text)

+        try:
+            response_3 = litellm.completion(model="luminous-base",
+                                            messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                            n=2)
+            pytest.fail(f"Error not raised when n=2 passed to provider")
+        except:
+            pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

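The tests above split providers into two groups: Cohere, AI21, and PaLM are expected to return one choice per candidate, while TogetherAI, NLP Cloud, and Aleph Alpha are expected to reject `n=2` (the Replicate case is simply allowed to fail silently). A caller that wants multiple candidates regardless of provider could fall back to repeated calls; this rough sketch is not part of the commit, and the helper name is made up:

```python
import litellm

def completion_with_n(model: str, messages: list, n: int):
    """Try native n>1 first; fall back to n sequential calls if the provider rejects it."""
    try:
        return litellm.completion(model=model, messages=messages, n=n)
    except Exception:
        # Provider rejects `n` (e.g. the new UnsupportedParamsError below); call it n times.
        responses = [litellm.completion(model=model, messages=messages) for _ in range(n)]
        merged = responses[0]
        # Collect the first choice of every call; indices are left as returned.
        merged["choices"] = [r.choices[0] for r in responses]
        return merged
```
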
@@ -97,6 +97,15 @@ last_fetched_at_keys = None
 # 'usage': {'prompt_tokens': 18, 'completion_tokens': 23, 'total_tokens': 41}
 # }

+class UnsupportedParamsError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        ) # Call the base class constructor with the parameters it needs
+
+
 def _generate_id(): # private helper function
     return 'chatcmpl-' + str(uuid.uuid4())

@@ -1008,7 +1017,7 @@ def get_optional_params( # use the openai defaults
         if litellm.add_function_to_prompt: # if user opts to add it to prompt instead
             optional_params["functions_unsupported_model"] = non_default_params.pop("functions")
         else:
-            raise ValueError(f"LiteLLM.Exception: Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.")
+            raise UnsupportedParamsError(status_code=500, message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.")

     def _check_valid_arg(supported_params):
         print_verbose(f"checking params for {model}")

@@ -1025,7 +1034,7 @@ def get_optional_params( # use the openai defaults
             else:
                 unsupported_params[k] = non_default_params[k]
         if unsupported_params and not litellm.drop_params:
-            raise ValueError(f"LiteLLM.Exception: {custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.")
+            raise UnsupportedParamsError(status_code=500, message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.")

     ## raise exception if provider doesn't support passed in param
     if custom_llm_provider == "anthropic":

@@ -1163,7 +1172,7 @@ def get_optional_params( # use the openai defaults
         if stop:
             optional_params["stopSequences"] = stop
         if max_tokens:
-            optional_params["maxOutputTokens"] = max_tokens
+            optional_params["max_output_tokens"] = max_tokens
     elif (
         custom_llm_provider == "vertex_ai"
     ):
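
With this hunk, the plain ValueError for unsupported parameters becomes a dedicated `UnsupportedParamsError` carrying a status code, and PaLM's `max_tokens` mapping follows the config rename. Callers that prefer litellm to silently strip unsupported parameters can keep using the existing escape hatch; a small sketch, using TogetherAI (which the tests above treat as having no native `n` support) as the example:

```python
import litellm

# Option 1: let litellm drop unsupported params instead of raising.
litellm.drop_params = True
response = litellm.completion(
    model="together_ai/togethercomputer/llama-2-70b-chat",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    n=2,  # silently dropped because drop_params is enabled
)

# Option 2: keep drop_params off and handle the error explicitly.
litellm.drop_params = False
try:
    litellm.completion(
        model="together_ai/togethercomputer/llama-2-70b-chat",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        n=2,
    )
except Exception as e:  # surfaced from the UnsupportedParamsError raised above
    print(e)
```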