diff --git a/litellm/main.py b/litellm/main.py
index 18453e0c7..f8dbfa1c1 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -12,7 +12,7 @@ import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, read_config_args
-from litellm.utils import get_ollama_response_stream, stream_to_string
+from litellm.utils import get_ollama_response_stream, stream_to_string, together_ai_completion_streaming
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
 new_response = {
@@ -321,9 +321,17 @@ def completion(
       headers = {"Authorization": f"Bearer {TOGETHER_AI_TOKEN}"}
       endpoint = 'https://api.together.xyz/inference'
       prompt = " ".join([message["content"] for message in messages]) # TODO: Add chat support for together AI
-
+
       ## LOGGING
       logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+      if stream == True or optional_params.get('stream_tokens') == True:
+        return together_ai_completion_streaming({
+          "model": model,
+          "prompt": prompt,
+          "request_type": "language-model-inference",
+          **optional_params
+          },
+          headers=headers)
       res = requests.post(endpoint, json={
           "model": model,
           "prompt": prompt,
@@ -334,9 +342,6 @@ def completion(
       )
       ## LOGGING
       logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": res.text}, logger_fn=logger_fn)
-      if stream == True:
-        response = CustomStreamWrapper(res, "together_ai")
-        return response
       completion_response = res.json()['output']['choices'][0]['text']
       prompt_tokens = len(encoding.encode(prompt))
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 8c0946131..1025c0a1d 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -294,4 +294,20 @@ def test_petals():
 #         pytest.fail(f"Error occurred: {e}")
 
 
+# import asyncio
+# def test_completion_together_ai_stream():
+#     try:
+#         response = completion(model="togethercomputer/llama-2-70b-chat", messages=messages, custom_llm_provider="together_ai", stream=True, max_tokens=200)
+#         print(response)
+#         asyncio.run(get_response(response))
+#         # print(string_response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+
+# async def get_response(generator):
+#     async for elem in generator:
+#         print(elem)
+#     return
+
diff --git a/litellm/utils.py b/litellm/utils.py
index 6d828b26a..e5f803e31 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -856,4 +856,42 @@ async def stream_to_string(generator):
         response += chunk["content"]
     return response
 
 
+
+########## Together AI streaming #############################
+async def together_ai_completion_streaming(json_data, headers):
+    session = aiohttp.ClientSession()
+    url = 'https://api.together.xyz/inference'
+    # headers = {
+    #     'Authorization': f'Bearer {together_ai_token}',
+    #     'Content-Type': 'application/json'
+    # }
+
+    # data = {
+    #     "model": "togethercomputer/llama-2-70b-chat",
+    #     "prompt": "write 1 page on the topic of the history of the united state",
+    #     "max_tokens": 1000,
+    #     "temperature": 0.7,
+    #     "top_p": 0.7,
+    #     "top_k": 50,
+    #     "repetition_penalty": 1,
+    #     "stream_tokens": True
+    # }
+    try:
+        async with session.post(url, json=json_data, headers=headers) as resp:
+            async for line in resp.content.iter_any():
+                # each streamed chunk arrives as a server-sent-event line, e.g. b'data: {...}'
+                if line:
+                    try:
+                        json_chunk = line.decode("utf-8")
+                        json_string = json_chunk.split('data: ')[1]
+                        # Convert the JSON string to a dictionary
+                        data_dict = json.loads(json_string)
+                        completion_response = data_dict['choices'][0]['text']
+                        completion_obj = {"role": "assistant", "content": ""}
+                        completion_obj["content"] = completion_response
+                        yield {"choices": [{"delta": completion_obj}]}
+                    except:
+                        pass # skip lines that aren't parseable completion chunks
+    finally:
+        await session.close()
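Not part of the diff: a minimal, hypothetical sketch of how the new streaming path could be consumed, mirroring the commented-out test above. The model name, messages list, and max_tokens value are illustrative assumptions rather than anything mandated by the change.

    import asyncio
    from litellm import completion

    messages = [{"role": "user", "content": "Hey, how's it going?"}]  # illustrative

    async def print_stream(generator):
        # completion(..., stream=True) for together_ai returns an async generator;
        # each chunk looks like {"choices": [{"delta": {"role": "assistant", "content": "..."}}]}
        async for chunk in generator:
            print(chunk["choices"][0]["delta"]["content"], end="")

    response = completion(
        model="togethercomputer/llama-2-70b-chat",
        messages=messages,
        custom_llm_provider="together_ai",
        stream=True,
        max_tokens=200,
    )
    asyncio.run(print_stream(response))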