fix: bug fix when n>1 passed in

Krrish Dholakia 2023-10-09 16:46:18 -07:00
parent 2004b449e8
commit 253e8d27db
8 changed files with 119 additions and 43 deletions

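For context on the fix: `n` is the OpenAI-style request for multiple completions per call, and the diff below makes the non-OpenAI handlers return one entry in `response.choices` per completion instead of keeping only the first. A minimal usage sketch of the behavior being fixed, assuming a configured Cohere API key (the model and prompt are the ones used in the updated tests):

```python
import litellm

# Ask for two completions in a single call; with this fix each one becomes
# its own entry in response.choices (previously only the first was kept).
response = litellm.completion(
    model="command-nightly",  # Cohere model exercised in the tests below
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    n=2,
)

for choice in response.choices:
    print(choice.message.content)
```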
View file

@@ -32,7 +32,7 @@ This list is constantly being updated.
|---|---|---|---|---|---|---|---|---|---|---|
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|Cohere| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
|Huggingface| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|Openrouter| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -43,7 +43,7 @@ This list is constantly being updated.
|TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|AlephAlpha| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|Palm| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|NLP Cloud| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|Petals| ✅ | ✅ | | ✅ | | | | | | |
|Ollama| ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | | |n

View file

@ -1,10 +1,10 @@
import os, types import os, types, traceback
import json import json
from enum import Enum from enum import Enum
import requests import requests
import time import time
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse from litellm.utils import ModelResponse, Choices, Message
import litellm import litellm
class AI21Error(Exception): class AI21Error(Exception):
@ -159,10 +159,14 @@ def completion(
) )
else: else:
try: try:
model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["data"]["text"] choices_list = []
model_response.choices[0].finish_reason = completion_response["completions"][0]["finishReason"]["reason"] for idx, item in enumerate(completion_response["completions"]):
message_obj = Message(content=item["data"]["text"])
choice_obj = Choices(finish_reason=item["finishReason"]["reason"], index=idx+1, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e: except Exception as e:
raise AI21Error(message=json.dumps(completion_response), status_code=response.status_code) raise AI21Error(message=traceback.format_exc(), status_code=response.status_code)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(

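The AI21 hunk above replaces the single-choice assignment with a loop that builds one `Choices`/`Message` pair per returned completion. A standalone sketch of that construction, fed a made-up AI21-style payload so it runs without any network call (the payload text and finish reasons are illustrative):

```python
from litellm.utils import ModelResponse, Choices, Message

# Fake AI21 "completions" payload mirroring the fields the handler reads.
fake_completions = [
    {"data": {"text": "First answer."}, "finishReason": {"reason": "endoftext"}},
    {"data": {"text": "Second answer."}, "finishReason": {"reason": "length"}},
]

model_response = ModelResponse()
choices_list = []
for idx, item in enumerate(fake_completions):
    message_obj = Message(content=item["data"]["text"])
    choice_obj = Choices(finish_reason=item["finishReason"]["reason"],
                         index=idx + 1, message=message_obj)
    choices_list.append(choice_obj)
model_response["choices"] = choices_list

print([c.message.content for c in model_response.choices])
# -> ['First answer.', 'Second answer.']
```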
View file

@@ -2,9 +2,9 @@ import os, types
import json
from enum import Enum
import requests
-import time
+import time, traceback
from typing import Callable, Optional
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Choices, Message
import litellm

class CohereError(Exception):
@@ -156,11 +156,16 @@ def completion(
            )
        else:
            try:
-                model_response["choices"][0]["message"]["content"] = completion_response["generations"][0]["text"]
-            except:
-                raise CohereError(message=json.dumps(completion_response), status_code=response.status_code)
+                choices_list = []
+                for idx, item in enumerate(completion_response["generations"]):
+                    message_obj = Message(content=item["text"])
+                    choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
+                    choices_list.append(choice_obj)
+                model_response["choices"] = choices_list
+            except Exception as e:
+                raise CohereError(message=traceback.format_exc(), status_code=response.status_code)

-        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        ## CALCULATING USAGE
        prompt_tokens = len(
            encoding.encode(prompt)
        )

View file

@@ -1,9 +1,9 @@
-import os, types
+import os, types, traceback
import json
from enum import Enum
import time
from typing import Callable, Optional
-from litellm.utils import ModelResponse, get_secret
+from litellm.utils import ModelResponse, get_secret, Choices, Message
import litellm
import sys
@@ -33,7 +33,7 @@ class PalmConfig():
    - `top_p` (float): The API uses combined nucleus and top-k sampling. `top_p` configures the nucleus sampling. It sets the maximum cumulative probability of tokens to sample from.
-    - `maxOutputTokens` (int): Sets the maximum number of tokens to be returned in the output
+    - `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output
    """
    context: Optional[str]=None
    examples: Optional[list]=None
@@ -41,7 +41,7 @@ class PalmConfig():
    candidate_count: Optional[int]=None
    top_k: Optional[int]=None
    top_p: Optional[float]=None
-    maxOutputTokens: Optional[int]=None
+    max_output_tokens: Optional[int]=None

    def __init__(self,
                 context: Optional[str]=None,
@@ -50,7 +50,7 @@ class PalmConfig():
                 candidate_count: Optional[int]=None,
                 top_k: Optional[int]=None,
                 top_p: Optional[float]=None,
-                 maxOutputTokens: Optional[int]=None) -> None:
+                 max_output_tokens: Optional[int]=None) -> None:
        locals_ = locals()
        for key, value in locals_.items():
@@ -110,10 +110,16 @@ def completion(
    logging_obj.pre_call(
        input=prompt,
        api_key="",
-        additional_args={"complete_input_dict": {}},
+        additional_args={"complete_input_dict": {"optional_params": optional_params}},
    )
    ## COMPLETION CALL
-    response = palm.chat(messages=prompt)
+    try:
+        response = palm.generate_text(prompt=prompt, **optional_params)
+    except Exception as e:
+        raise PalmError(
+            message=str(e),
+            status_code=500,
+        )

    ## LOGGING
    logging_obj.post_call(
@@ -124,18 +130,17 @@ def completion(
    )
    print_verbose(f"raw model_response: {response}")
    ## RESPONSE OBJECT
-    completion_response = response.last
-    if "error" in completion_response:
-        raise PalmError(
-            message=completion_response["error"],
-            status_code=response.status_code,
-        )
-    else:
-        try:
-            model_response["choices"][0]["message"]["content"] = completion_response
-        except:
-            raise PalmError(message=json.dumps(completion_response), status_code=response.status_code)
+    completion_response = response
+    try:
+        choices_list = []
+        for idx, item in enumerate(completion_response.candidates):
+            message_obj = Message(content=item["output"])
+            choice_obj = Choices(index=idx+1, message=message_obj)
+            choices_list.append(choice_obj)
+        model_response["choices"] = choices_list
+    except Exception as e:
+        traceback.print_exc()
+        raise PalmError(message=traceback.format_exc(), status_code=response.status_code)

    ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
    prompt_tokens = len(

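The PaLM handler above switches from `palm.chat` to the text endpoint `palm.generate_text` and forwards `optional_params` (which is also where the renamed `max_output_tokens` lands). A hedged sketch of the underlying `google.generativeai` call being wrapped; the `PALM_API_KEY` environment variable name and the parameter values here are illustrative assumptions:

```python
import os
import google.generativeai as palm

palm.configure(api_key=os.environ["PALM_API_KEY"])  # assumed env var name

completion = palm.generate_text(
    prompt="Hello, how are you?",
    candidate_count=2,        # multiple candidates, analogous to n=2
    max_output_tokens=10,
)

# Each candidate is a dict with an "output" field, which the handler above
# maps onto one Choices entry per candidate.
for candidate in completion.candidates:
    print(candidate["output"])
```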
View file

@@ -161,7 +161,7 @@ def completion(
            raise TogetherAIError(
                message=json.dumps(completion_response["output"]), status_code=response.status_code
            )
        completion_text = completion_response["output"]["choices"][0]["text"]

        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.

View file

@@ -50,7 +50,7 @@ def claude_test_completion():
    try:
        # OVERRIDE WITH DYNAMIC MAX TOKENS
        response_1 = litellm.completion(
-            model="together_ai/togethercomputer/llama-2-70b-chat",
+            model="claude-instant-1",
            messages=[{ "content": "Hello, how are you?","role": "user"}],
            max_tokens=10
        )
@@ -60,7 +60,7 @@ def claude_test_completion():
        # USE CONFIG TOKENS
        response_2 = litellm.completion(
-            model="together_ai/togethercomputer/llama-2-70b-chat",
+            model="claude-instant-1",
            messages=[{ "content": "Hello, how are you?","role": "user"}],
        )
        # Add any assertions here to check the response
@@ -68,6 +68,14 @@ def claude_test_completion():
        response_2_text = response_2.choices[0].message.content
        assert len(response_2_text) > len(response_1_text)

+        try:
+            response_3 = litellm.completion(model="claude-instant-1",
+                                            messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                            n=2)
+        except Exception as e:
+            print(e)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -99,6 +107,12 @@ def replicate_test_completion():
        response_2_text = response_2.choices[0].message.content
        assert len(response_2_text) > len(response_1_text)

+        try:
+            response_3 = litellm.completion(model="meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3",
+                                            messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                            n=2)
+        except:
+            pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -107,8 +121,8 @@ def replicate_test_completion():
# Cohere
def cohere_test_completion():
-    litellm.CohereConfig(max_tokens=200)
-    # litellm.set_verbose=True
+    # litellm.CohereConfig(max_tokens=200)
+    litellm.set_verbose=True
    try:
        # OVERRIDE WITH DYNAMIC MAX TOKENS
        response_1 = litellm.completion(
@@ -126,6 +140,11 @@ def cohere_test_completion():
        response_2_text = response_2.choices[0].message.content
        assert len(response_2_text) > len(response_1_text)

+        response_3 = litellm.completion(model="command-nightly",
+                                        messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                        n=2)
+        assert len(response_3.choices) > 1
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -135,7 +154,7 @@ def cohere_test_completion():
def ai21_test_completion():
    litellm.AI21Config(maxTokens=10)
-    # litellm.set_verbose=True
+    litellm.set_verbose=True
    try:
        # OVERRIDE WITH DYNAMIC MAX TOKENS
        response_1 = litellm.completion(
@@ -155,6 +174,11 @@ def ai21_test_completion():
        print(f"response_2_text: {response_2_text}")
        assert len(response_2_text) < len(response_1_text)

+        response_3 = litellm.completion(model="j2-light",
+                                        messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                        n=2)
+        assert len(response_3.choices) > 1
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -164,7 +188,7 @@ def ai21_test_completion():
def togetherai_test_completion():
    litellm.TogetherAIConfig(max_tokens=10)
-    # litellm.set_verbose=True
+    litellm.set_verbose=True
    try:
        # OVERRIDE WITH DYNAMIC MAX TOKENS
        response_1 = litellm.completion(
@@ -184,6 +208,14 @@ def togetherai_test_completion():
        print(f"response_2_text: {response_2_text}")
        assert len(response_2_text) < len(response_1_text)

+        try:
+            response_3 = litellm.completion(model="together_ai/togethercomputer/llama-2-70b-chat",
+                                            messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                            n=2)
+            pytest.fail(f"Error not raised when n=2 passed to provider")
+        except:
+            pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -192,7 +224,7 @@ def togetherai_test_completion():
# Palm
def palm_test_completion():
-    litellm.PalmConfig(maxOutputTokens=10)
+    litellm.PalmConfig(max_output_tokens=10, temperature=0.9)
    # litellm.set_verbose=True
    try:
        # OVERRIDE WITH DYNAMIC MAX TOKENS
@@ -213,6 +245,11 @@ def palm_test_completion():
        print(f"response_2_text: {response_2_text}")
        assert len(response_2_text) < len(response_1_text)

+        response_3 = litellm.completion(model="palm/chat-bison",
+                                        messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                        n=2)
+        assert len(response_3.choices) > 1
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -242,6 +279,14 @@ def nlp_cloud_test_completion():
        print(f"response_2_text: {response_2_text}")
        assert len(response_2_text) < len(response_1_text)

+        try:
+            response_3 = litellm.completion(model="dolphin",
+                                            messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                            n=2)
+            pytest.fail(f"Error not raised when n=2 passed to provider")
+        except:
+            pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -271,6 +316,14 @@ def aleph_alpha_test_completion():
        print(f"response_2_text: {response_2_text}")
        assert len(response_2_text) < len(response_1_text)

+        try:
+            response_3 = litellm.completion(model="luminous-base",
+                                            messages=[{ "content": "Hello, how are you?","role": "user"}],
+                                            n=2)
+            pytest.fail(f"Error not raised when n=2 passed to provider")
+        except:
+            pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

View file

@@ -97,6 +97,15 @@ last_fetched_at_keys = None
# 'usage': {'prompt_tokens': 18, 'completion_tokens': 23, 'total_tokens': 41}
# }

+class UnsupportedParamsError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
def _generate_id(): # private helper function
    return 'chatcmpl-' + str(uuid.uuid4())
@@ -1008,7 +1017,7 @@ def get_optional_params( # use the openai defaults
        if litellm.add_function_to_prompt: # if user opts to add it to prompt instead
            optional_params["functions_unsupported_model"] = non_default_params.pop("functions")
        else:
-            raise ValueError(f"LiteLLM.Exception: Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.")
+            raise UnsupportedParamsError(status_code=500, message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.")

    def _check_valid_arg(supported_params):
        print_verbose(f"checking params for {model}")
@@ -1025,7 +1034,7 @@ def get_optional_params( # use the openai defaults
            else:
                unsupported_params[k] = non_default_params[k]
        if unsupported_params and not litellm.drop_params:
-            raise ValueError(f"LiteLLM.Exception: {custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.")
+            raise UnsupportedParamsError(status_code=500, message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.")

    ## raise exception if provider doesn't support passed in param
    if custom_llm_provider == "anthropic":
@@ -1163,7 +1172,7 @@ def get_optional_params( # use the openai defaults
        if stop:
            optional_params["stopSequences"] = stop
        if max_tokens:
-            optional_params["maxOutputTokens"] = max_tokens
+            optional_params["max_output_tokens"] = max_tokens
    elif (
        custom_llm_provider == "vertex_ai"
    ):
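The `UnsupportedParamsError` added above is what a caller now hits when a parameter such as `n` reaches a provider that cannot honor it and `litellm.drop_params` is left unset. A minimal sketch of the two call patterns, mirroring `togetherai_test_completion` above (a configured TogetherAI key is assumed for the second call to actually complete):

```python
import litellm

messages = [{"role": "user", "content": "Hello, how are you?"}]

# Default behavior: n=2 against a provider without native support raises,
# which the TogetherAI / NLP Cloud / Aleph Alpha tests above rely on.
try:
    litellm.completion(model="together_ai/togethercomputer/llama-2-70b-chat",
                       messages=messages, n=2)
except Exception as e:
    print(f"raised as expected: {e}")

# Opt-in alternative: drop unsupported params instead of raising.
litellm.drop_params = True
response = litellm.completion(model="together_ai/togethercomputer/llama-2-70b-chat",
                              messages=messages, n=2)  # n is dropped by the check above
print(response.choices[0].message.content)
```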