fix: bug fix when n>1 is passed in

Krrish Dholakia 2023-10-09 16:46:18 -07:00
parent 2004b449e8
commit 253e8d27db
8 changed files with 119 additions and 43 deletions
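For context: `n` follows the OpenAI convention of requesting several independent completion choices from a single call. A minimal sketch of the call pattern this commit fixes, using the Cohere model exercised in the tests below (the provider's API key is assumed to be set in the environment):

```python
import litellm

# Ask a provider that supports multiple generations for two choices at once.
response = litellm.completion(
    model="command-nightly",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    n=2,
)

# With this fix, each generation is mapped into its own entry in choices
# instead of only the first one being returned.
for choice in response.choices:
    print(choice.index, choice.message.content)
```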

View file

@@ -32,7 +32,7 @@ This list is constantly being updated.
|---|---|---|---|---|---|---|---|---|---|---|
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|Cohere| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
|Huggingface| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|Openrouter| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -43,7 +43,7 @@ This list is constantly being updated.
|TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|AlephAlpha| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|Palm| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|NLP Cloud| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|Petals| ✅ | ✅ | | ✅ | | | | | | |
|Ollama| ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | | |

View file

@@ -1,10 +1,10 @@
import os, types
import os, types, traceback
import json
from enum import Enum
import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse
from litellm.utils import ModelResponse, Choices, Message
import litellm
class AI21Error(Exception):
@@ -159,10 +159,14 @@ def completion(
)
else:
try:
model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["data"]["text"]
model_response.choices[0].finish_reason = completion_response["completions"][0]["finishReason"]["reason"]
choices_list = []
for idx, item in enumerate(completion_response["completions"]):
message_obj = Message(content=item["data"]["text"])
choice_obj = Choices(finish_reason=item["finishReason"]["reason"], index=idx+1, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise AI21Error(message=json.dumps(completion_response), status_code=response.status_code)
raise AI21Error(message=traceback.format_exc(), status_code=response.status_code)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
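The same mapping, shown standalone with a hand-written payload in the shape the AI21 handler expects (`completions[i]["data"]["text"]` and `completions[i]["finishReason"]["reason"]`); the text and finish-reason values are illustrative:

```python
from litellm.utils import Choices, Message

# Illustrative AI21-style payload; only the key layout matters here.
completion_response = {
    "completions": [
        {"data": {"text": "First answer"}, "finishReason": {"reason": "endoftext"}},
        {"data": {"text": "Second answer"}, "finishReason": {"reason": "endoftext"}},
    ]
}

choices_list = []
for idx, item in enumerate(completion_response["completions"]):
    message_obj = Message(content=item["data"]["text"])
    choice_obj = Choices(
        finish_reason=item["finishReason"]["reason"], index=idx + 1, message=message_obj
    )
    choices_list.append(choice_obj)

assert len(choices_list) == 2  # one Choices entry per generation
```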

View file

@@ -2,9 +2,9 @@ import os, types
import json
from enum import Enum
import requests
import time
import time, traceback
from typing import Callable, Optional
from litellm.utils import ModelResponse
from litellm.utils import ModelResponse, Choices, Message
import litellm
class CohereError(Exception):
@@ -156,11 +156,16 @@ def completion(
)
else:
try:
model_response["choices"][0]["message"]["content"] = completion_response["generations"][0]["text"]
except:
raise CohereError(message=json.dumps(completion_response), status_code=response.status_code)
choices_list = []
for idx, item in enumerate(completion_response["generations"]):
message_obj = Message(content=item["text"])
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise CohereError(message=traceback.format_exc(), status_code=response.status_code)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
)
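One effect of the broadened `except Exception` above is that mapping failures now surface the full traceback in the error message rather than the raw payload. A small sketch of that reporting pattern, with a stand-in error class mirroring the `(message, status_code)` shape used here:

```python
import traceback

class ProviderError(Exception):
    # Stand-in mirroring the (message, status_code) shape of CohereError.
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(self.message)

# A malformed payload: "finish_reason" is missing from the generation.
completion_response = {"generations": [{"text": "hi"}]}

try:
    for idx, item in enumerate(completion_response["generations"]):
        _ = item["finish_reason"]  # KeyError here
except Exception:
    err = ProviderError(status_code=500, message=traceback.format_exc())
    print(err.message)  # the traceback pinpoints the failing line and key
```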

View file

@@ -1,9 +1,9 @@
import os, types
import os, types, traceback
import json
from enum import Enum
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, get_secret
from litellm.utils import ModelResponse, get_secret, Choices, Message
import litellm
import sys
@@ -33,7 +33,7 @@ class PalmConfig():
- `top_p` (float): The API uses combined nucleus and top-k sampling. `top_p` configures the nucleus sampling. It sets the maximum cumulative probability of tokens to sample from.
- `maxOutputTokens` (int): Sets the maximum number of tokens to be returned in the output
- `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output
"""
context: Optional[str]=None
examples: Optional[list]=None
@@ -41,7 +41,7 @@ class PalmConfig():
candidate_count: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
maxOutputTokens: Optional[int]=None
max_output_tokens: Optional[int]=None
def __init__(self,
context: Optional[str]=None,
@@ -50,7 +50,7 @@ class PalmConfig():
candidate_count: Optional[int]=None,
top_k: Optional[int]=None,
top_p: Optional[float]=None,
maxOutputTokens: Optional[int]=None) -> None:
max_output_tokens: Optional[int]=None) -> None:
locals_ = locals()
for key, value in locals_.items():
@@ -110,10 +110,16 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": {}},
additional_args={"complete_input_dict": {"optional_params": optional_params}},
)
## COMPLETION CALL
response = palm.chat(messages=prompt)
try:
response = palm.generate_text(prompt=prompt, **optional_params)
except Exception as e:
raise PalmError(
message=str(e),
status_code=500,
)
## LOGGING
logging_obj.post_call(
@@ -124,18 +130,17 @@
)
print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT
completion_response = response.last
if "error" in completion_response:
raise PalmError(
message=completion_response["error"],
status_code=response.status_code,
)
else:
try:
model_response["choices"][0]["message"]["content"] = completion_response
except:
raise PalmError(message=json.dumps(completion_response), status_code=response.status_code)
completion_response = response
try:
choices_list = []
for idx, item in enumerate(completion_response.candidates):
message_obj = Message(content=item["output"])
choice_obj = Choices(index=idx+1, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
traceback.print_exc()
raise PalmError(message=traceback.format_exc(), status_code=response.status_code)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
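For reference, a sketch of the underlying SDK call the handler now makes, assuming the legacy `google.generativeai` PaLM text API used here; `candidate_count` is how a multi-choice request surfaces as several entries in `response.candidates`, and `max_output_tokens` is the renamed config field from the change above (the API key is a placeholder):

```python
import google.generativeai as palm

palm.configure(api_key="<PALM_API_KEY>")  # placeholder

# generate_text replaces the previous palm.chat call; optional params such as
# candidate_count, max_output_tokens, temperature, top_k and top_p are passed
# straight through, matching the PalmConfig fields above.
response = palm.generate_text(
    prompt="Hello, how are you?",
    candidate_count=2,
    max_output_tokens=10,
    temperature=0.9,
)

for idx, item in enumerate(response.candidates):
    print(idx, item["output"])  # each candidate becomes one Choices entry
```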

View file

@@ -161,7 +161,7 @@ def completion(
raise TogetherAIError(
message=json.dumps(completion_response["output"]), status_code=response.status_code
)
completion_text = completion_response["output"]["choices"][0]["text"]
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
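TogetherAI responses are still mapped as a single choice here, so requesting `n>1` is rejected upstream by the new `UnsupportedParamsError` in `utils.py` (see below) unless the caller opts out. A sketch of both caller-side behaviours, using the model name from the tests (a TogetherAI API key is assumed to be set in the environment):

```python
import litellm

messages = [{"role": "user", "content": "Hello, how are you?"}]

# Default behaviour: the unsupported parameter is rejected before any request
# is sent to the provider.
try:
    litellm.completion(
        model="together_ai/togethercomputer/llama-2-70b-chat",
        messages=messages,
        n=2,
    )
except Exception as e:
    print(f"rejected as expected: {e}")

# Opt-out: silently drop parameters the provider does not support instead.
litellm.drop_params = True
response = litellm.completion(
    model="together_ai/togethercomputer/llama-2-70b-chat",
    messages=messages,
)
print(response.choices[0].message.content)
```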

View file

@@ -50,7 +50,7 @@ def claude_test_completion():
try:
# OVERRIDE WITH DYNAMIC MAX TOKENS
response_1 = litellm.completion(
model="together_ai/togethercomputer/llama-2-70b-chat",
model="claude-instant-1",
messages=[{ "content": "Hello, how are you?","role": "user"}],
max_tokens=10
)
@@ -60,7 +60,7 @@ def claude_test_completion():
# USE CONFIG TOKENS
response_2 = litellm.completion(
model="together_ai/togethercomputer/llama-2-70b-chat",
model="claude-instant-1",
messages=[{ "content": "Hello, how are you?","role": "user"}],
)
# Add any assertions here to check the response
@@ -68,6 +68,14 @@ def claude_test_completion():
response_2_text = response_2.choices[0].message.content
assert len(response_2_text) > len(response_1_text)
try:
response_3 = litellm.completion(model="claude-instant-1",
messages=[{ "content": "Hello, how are you?","role": "user"}],
n=2)
except Exception as e:
print(e)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -99,6 +107,12 @@ def replicate_test_completion():
response_2_text = response_2.choices[0].message.content
assert len(response_2_text) > len(response_1_text)
try:
response_3 = litellm.completion(model="meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3",
messages=[{ "content": "Hello, how are you?","role": "user"}],
n=2)
except:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -107,8 +121,8 @@ def replicate_test_completion():
# Cohere
def cohere_test_completion():
litellm.CohereConfig(max_tokens=200)
# litellm.set_verbose=True
# litellm.CohereConfig(max_tokens=200)
litellm.set_verbose=True
try:
# OVERRIDE WITH DYNAMIC MAX TOKENS
response_1 = litellm.completion(
@@ -126,6 +140,11 @@ def cohere_test_completion():
response_2_text = response_2.choices[0].message.content
assert len(response_2_text) > len(response_1_text)
response_3 = litellm.completion(model="command-nightly",
messages=[{ "content": "Hello, how are you?","role": "user"}],
n=2)
assert len(response_3.choices) > 1
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -135,7 +154,7 @@ def cohere_test_completion():
def ai21_test_completion():
litellm.AI21Config(maxTokens=10)
# litellm.set_verbose=True
litellm.set_verbose=True
try:
# OVERRIDE WITH DYNAMIC MAX TOKENS
response_1 = litellm.completion(
@@ -155,6 +174,11 @@ def ai21_test_completion():
print(f"response_2_text: {response_2_text}")
assert len(response_2_text) < len(response_1_text)
response_3 = litellm.completion(model="j2-light",
messages=[{ "content": "Hello, how are you?","role": "user"}],
n=2)
assert len(response_3.choices) > 1
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -164,7 +188,7 @@ def ai21_test_completion():
def togetherai_test_completion():
litellm.TogetherAIConfig(max_tokens=10)
# litellm.set_verbose=True
litellm.set_verbose=True
try:
# OVERRIDE WITH DYNAMIC MAX TOKENS
response_1 = litellm.completion(
@@ -184,6 +208,14 @@ def togetherai_test_completion():
print(f"response_2_text: {response_2_text}")
assert len(response_2_text) < len(response_1_text)
try:
response_3 = litellm.completion(model="together_ai/togethercomputer/llama-2-70b-chat",
messages=[{ "content": "Hello, how are you?","role": "user"}],
n=2)
pytest.fail(f"Error not raised when n=2 passed to provider")
except:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -192,7 +224,7 @@ def togetherai_test_completion():
# Palm
def palm_test_completion():
litellm.PalmConfig(maxOutputTokens=10)
litellm.PalmConfig(max_output_tokens=10, temperature=0.9)
# litellm.set_verbose=True
try:
# OVERRIDE WITH DYNAMIC MAX TOKENS
@@ -213,6 +245,11 @@ def palm_test_completion():
print(f"response_2_text: {response_2_text}")
assert len(response_2_text) < len(response_1_text)
response_3 = litellm.completion(model="palm/chat-bison",
messages=[{ "content": "Hello, how are you?","role": "user"}],
n=2)
assert len(response_3.choices) > 1
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -242,6 +279,14 @@ def nlp_cloud_test_completion():
print(f"response_2_text: {response_2_text}")
assert len(response_2_text) < len(response_1_text)
try:
response_3 = litellm.completion(model="dolphin",
messages=[{ "content": "Hello, how are you?","role": "user"}],
n=2)
pytest.fail(f"Error not raised when n=2 passed to provider")
except:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -271,6 +316,14 @@ def aleph_alpha_test_completion():
print(f"response_2_text: {response_2_text}")
assert len(response_2_text) < len(response_1_text)
try:
response_3 = litellm.completion(model="luminous-base",
messages=[{ "content": "Hello, how are you?","role": "user"}],
n=2)
pytest.fail(f"Error not raised when n=2 passed to provider")
except:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
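The new test cases all follow one pattern: providers that support `n` must return more than one choice, while the rest are expected to raise. A hypothetical helper (not part of the test file) that condenses that pattern:

```python
import litellm

def check_n_support(model: str, supports_n: bool):
    # Hypothetical helper condensing the pattern above; illustrative only.
    raised = False
    response = None
    try:
        response = litellm.completion(
            model=model,
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            n=2,
        )
    except Exception:
        raised = True

    if supports_n:
        assert not raised and len(response.choices) > 1
    else:
        assert raised

# Mirrors two of the cases exercised above:
# check_n_support("command-nightly", supports_n=True)
# check_n_support("together_ai/togethercomputer/llama-2-70b-chat", supports_n=False)
```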

View file

@@ -97,6 +97,15 @@ last_fetched_at_keys = None
# 'usage': {'prompt_tokens': 18, 'completion_tokens': 23, 'total_tokens': 41}
# }
class UnsupportedParamsError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
def _generate_id(): # private helper function
return 'chatcmpl-' + str(uuid.uuid4())
@@ -1008,7 +1017,7 @@ def get_optional_params( # use the openai defaults
if litellm.add_function_to_prompt: # if user opts to add it to prompt instead
optional_params["functions_unsupported_model"] = non_default_params.pop("functions")
else:
raise ValueError(f"LiteLLM.Exception: Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.")
raise UnsupportedParamsError(status_code=500, message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.")
def _check_valid_arg(supported_params):
print_verbose(f"checking params for {model}")
@@ -1025,7 +1034,7 @@ def get_optional_params( # use the openai defaults
else:
unsupported_params[k] = non_default_params[k]
if unsupported_params and not litellm.drop_params:
raise ValueError(f"LiteLLM.Exception: {custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.")
raise UnsupportedParamsError(status_code=500, message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.")
## raise exception if provider doesn't support passed in param
if custom_llm_provider == "anthropic":
@@ -1163,7 +1172,7 @@ def get_optional_params( # use the openai defaults
if stop:
optional_params["stopSequences"] = stop
if max_tokens:
optional_params["maxOutputTokens"] = max_tokens
optional_params["max_output_tokens"] = max_tokens
elif (
custom_llm_provider == "vertex_ai"
):
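The same `UnsupportedParamsError` also covers OpenAI-style `functions` for providers without function calling, keeping a status code on the error. A sketch of the caller-side options, with an illustrative model and function schema:

```python
import litellm

functions = [
    {
        "name": "get_weather",  # illustrative schema
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
]

try:
    litellm.completion(
        model="command-nightly",
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        functions=functions,
    )
except Exception as e:
    # Raised via UnsupportedParamsError if the provider lacks function calling.
    print(e)

# Or fold the function definitions into the prompt instead of raising:
litellm.add_function_to_prompt = True
```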