forked from phoenix/litellm-mirror
add together-ai
parent 441fff406c
commit 23fc936b65
3 changed files with 59 additions and 4 deletions
@@ -44,7 +44,7 @@ def completion(
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
     *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
-    hugging_face = False, replicate=False,
+    hugging_face = False, replicate=False, together_ai = False
 ):
     try:
         global new_response
@@ -55,7 +55,7 @@ def completion(
         temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
         presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
         # params to identify the model
-        model=model, replicate=replicate, hugging_face=hugging_face
+        model=model, replicate=replicate, hugging_face=hugging_face, together_ai=together_ai
     )
     if azure == True:
         # azure configs
@@ -362,6 +362,40 @@ def completion(
             "total_tokens": prompt_tokens + completion_tokens
         }
         response = model_response
+    elif together_ai == True:
+        import requests
+        TOGETHER_AI_TOKEN = get_secret("TOGETHER_AI_TOKEN")
+        headers = {"Authorization": f"Bearer {TOGETHER_AI_TOKEN}"}
+        endpoint = 'https://api.together.xyz/inference'
+        prompt = " ".join([message["content"] for message in messages])  # TODO: Add chat support for together AI
+
+        ## LOGGING
+        logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+        res = requests.post(endpoint, json={
+                "model": model,
+                "prompt": prompt,
+                "request_type": "language-model-inference",
+            },
+            headers=headers
+        )
+
+        completion_response = res.json()['output']['choices'][0]['text']
+
+        ## LOGGING
+        logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+        prompt_tokens = len(encoding.encode(prompt))
+        completion_tokens = len(encoding.encode(completion_response))
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = completion_response
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens
+        }
+        response = model_response
+
     else:
         ## LOGGING
         logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
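For reference, the new together_ai branch amounts to one POST against the Together AI inference endpoint followed by a lookup of output.choices[0].text in the JSON reply. Below is a minimal standalone sketch of that round trip; it assumes TOGETHER_AI_TOKEN is exported in the environment (mirroring get_secret("TOGETHER_AI_TOKEN") above) and that the response keeps the shape the parsing relies on.

import os
import requests

# Assumed: the token is available as an environment variable.
TOGETHER_AI_TOKEN = os.environ["TOGETHER_AI_TOKEN"]
headers = {"Authorization": f"Bearer {TOGETHER_AI_TOKEN}"}
endpoint = "https://api.together.xyz/inference"

# The commit joins all message contents into one prompt string
# (chat-style prompting is left as a TODO).
messages = [{"role": "user", "content": "Hello, how are you?"}]
prompt = " ".join(message["content"] for message in messages)

res = requests.post(
    endpoint,
    json={
        "model": "togethercomputer/mpt-30b-chat",  # model name taken from the new test
        "prompt": prompt,
        "request_type": "language-model-inference",
    },
    headers=headers,
)

# The branch above assumes a payload shaped like:
# {"output": {"choices": [{"text": "..."}]}}
print(res.json()["output"]["choices"][0]["text"])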
@@ -26,7 +26,6 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_claude()
 def test_completion_claude_stream():
     try:
         messages = [
@@ -40,7 +39,6 @@ def test_completion_claude_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_claude_stream()
 def test_completion_hf_api():
     try:
         user_message = "write some code to find the sum of two numbers"
@@ -199,3 +197,13 @@ def test_completion_replicate_stability():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+######## Test TogetherAI ########
+def test_completion_together_ai():
+    model_name = "togethercomputer/mpt-30b-chat"
+    try:
+        response = completion(model=model_name, messages=messages, together_ai=True)
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
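End to end, the new test exercises the flag through the public completion() call. A hedged usage sketch, assuming the test module's existing import style (from litellm import completion) and a TOGETHER_AI_TOKEN already set in the environment:

import os
from litellm import completion  # import path assumed from the existing tests

os.environ["TOGETHER_AI_TOKEN"] = "your-together-ai-key"  # placeholder value

messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = completion(
    model="togethercomputer/mpt-30b-chat",
    messages=messages,
    together_ai=True,  # routes the request to api.together.xyz/inference
)
# The response object mirrors the OpenAI-style shape built in completion() above.
print(response["choices"][0]["message"]["content"])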
@@ -216,6 +216,7 @@ def get_optional_params(
     model = None,
     replicate = False,
     hugging_face = False,
+    together_ai = False,
 ):
     optional_params = {}
     if model in litellm.anthropic_models:
@@ -244,6 +245,18 @@ def get_optional_params(
         if stream:
             optional_params["stream"] = stream
         return optional_params
+    elif together_ai == True:
+        if stream:
+            optional_params["stream"] = stream
+        if temperature != 1:
+            optional_params["temperature"] = temperature
+        if top_p != 1:
+            optional_params["top_p"] = top_p
+        if max_tokens != float('inf'):
+            optional_params["max_tokens"] = max_tokens
+        if frequency_penalty != 0:
+            optional_params["frequency_penalty"] = frequency_penalty
+
     else:  # assume passing in params for openai/azure openai
         if functions != []:
             optional_params["functions"] = functions
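The together_ai branch of get_optional_params only forwards values that differ from their OpenAI-style defaults (temperature 1, top_p 1, max_tokens inf, frequency_penalty 0, stream False). A hedged illustration of the resulting payload; the import path is an assumption, not something this diff establishes:

from litellm.utils import get_optional_params  # assumed import path

params = get_optional_params(
    temperature=0.2,   # non-default, so it is forwarded
    max_tokens=256,    # non-default, so it is forwarded
    top_p=1,           # default, dropped
    stream=False,      # default, dropped
    together_ai=True,
)
print(params)  # expected: {"temperature": 0.2, "max_tokens": 256}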