diff --git a/litellm/main.py b/litellm/main.py
index 59261fc86..81dbc76b5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -44,7 +44,7 @@ def completion(
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
     *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
-    hugging_face = False, replicate=False,
+    hugging_face = False, replicate=False, together_ai = False
   ):
   try:
     global new_response
@@ -55,7 +55,7 @@ def completion(
       temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
       presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
       # params to identify the model
-      model=model, replicate=replicate, hugging_face=hugging_face
+      model=model, replicate=replicate, hugging_face=hugging_face, together_ai=together_ai
     )
     if azure == True:
       # azure configs
@@ -362,6 +362,40 @@
         "total_tokens": prompt_tokens + completion_tokens
       }
       response = model_response
+    elif together_ai == True:
+      import requests
+      TOGETHER_AI_TOKEN = get_secret("TOGETHER_AI_TOKEN")
+      headers = {"Authorization": f"Bearer {TOGETHER_AI_TOKEN}"}
+      endpoint = 'https://api.together.xyz/inference'
+      prompt = " ".join([message["content"] for message in messages]) # TODO: Add chat support for together AI
+
+      ## LOGGING
+      logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+      res = requests.post(endpoint, json={
+          "model": model,
+          "prompt": prompt,
+          "request_type": "language-model-inference",
+        },
+        headers=headers
+      )
+
+      completion_response = res.json()['output']['choices'][0]['text']
+
+      ## LOGGING
+      logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+      prompt_tokens = len(encoding.encode(prompt))
+      completion_tokens = len(encoding.encode(completion_response))
+      ## RESPONSE OBJECT
+      model_response["choices"][0]["message"]["content"] = completion_response
+      model_response["created"] = time.time()
+      model_response["model"] = model
+      model_response["usage"] = {
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+        "total_tokens": prompt_tokens + completion_tokens
+      }
+      response = model_response
+
     else:
       ## LOGGING
       logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 28ebba367..9aa475d4c 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -26,7 +26,6 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_claude()
 def test_completion_claude_stream():
     try:
         messages = [
@@ -40,7 +39,6 @@
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_claude_stream()
 def test_completion_hf_api():
     try:
         user_message = "write some code to find the sum of two numbers"
@@ -197,5 +195,15 @@
         for result in response:
             print(result)
         print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+######## Test TogetherAI ########
+def test_completion_together_ai():
+    model_name = "togethercomputer/mpt-30b-chat"
+    try:
+        response = completion(model=model_name, messages=messages, together_ai=True)
+        # Add any assertions here to check the response
+        print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 04bef0709..8369b0202 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -216,6 +216,7 @@ def get_optional_params(
   model = None,
   replicate = False,
   hugging_face = False,
+  together_ai = False,
 ):
   optional_params = {}
   if model in litellm.anthropic_models:
@@ -244,6 +245,18 @@
     if stream:
       optional_params["stream"] = stream
     return optional_params
+  elif together_ai == True:
+    if stream:
+      optional_params["stream"] = stream
+    if temperature != 1:
+      optional_params["temperature"] = temperature
+    if top_p != 1:
+      optional_params["top_p"] = top_p
+    if max_tokens != float('inf'):
+      optional_params["max_tokens"] = max_tokens
+    if frequency_penalty != 0:
+      optional_params["frequency_penalty"] = frequency_penalty
+
   else:# assume passing in params for openai/azure openai
     if functions != []:
       optional_params["functions"] = functions
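For reference, a minimal usage sketch (not part of the patch) of how the new flag would be exercised once this change lands, mirroring the added test. It assumes get_secret resolves TOGETHER_AI_TOKEN from the environment and that the model name used in the test is served by Together AI; the key value and the prompt text below are placeholders.

# Minimal sketch, assuming TOGETHER_AI_TOKEN is set and the model is available on Together AI.
import os
from litellm import completion

os.environ["TOGETHER_AI_TOKEN"] = "<your-together-ai-key>"  # placeholder key
messages = [{"role": "user", "content": "Hello, how are you?"}]  # example prompt

response = completion(
    model="togethercomputer/mpt-30b-chat",
    messages=messages,
    together_ai=True,  # route the request to https://api.together.xyz/inference
)
# The handler fills an OpenAI-style response object, so the generated text is here:
print(response["choices"][0]["message"]["content"])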