From a74538ed6efadb2a4480137ea0c60d6fb196aad0 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Thu, 17 Aug 2023 18:01:12 -0700
Subject: [PATCH] fix tg computer

---
 litellm/tests/test_completion.py | 5 ++++-
 litellm/utils.py                 | 4 ++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 863dc7c459..1158870e74 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -321,8 +321,10 @@ def test_petals():
 
 # import asyncio
 # def test_completion_together_ai_stream():
+#     user_message = "Write 1pg about YC & litellm"
+#     messages = [{ "content": user_message,"role": "user"}]
 #     try:
-#         response = completion(model="togethercomputer/llama-2-70b-chat", messages=messages, custom_llm_provider="together_ai", stream=True, max_tokens=200)
+#         response = completion(model="togethercomputer/llama-2-70b-chat", messages=messages, stream=True, max_tokens=800)
 #         print(response)
 #         asyncio.run(get_response(response))
 #         # print(string_response)
@@ -335,4 +337,5 @@ def test_petals():
 #             print(elem)
 #     return
 
 
+# test_completion_together_ai_stream()
diff --git a/litellm/utils.py b/litellm/utils.py
index b46b126101..e6816cced3 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -392,7 +392,7 @@ def get_optional_params(
         if stream:
             optional_params["stream"] = stream
         return optional_params
-    elif custom_llm_provider == "together_ai":
+    elif custom_llm_provider == "together_ai" or ("togethercomputer" in model):
         if stream:
             optional_params["stream_tokens"] = stream
         if temperature != 1:
@@ -897,7 +897,7 @@ class CustomStreamWrapper:
         elif self.model == "replicate":
             chunk = next(self.completion_stream)
             completion_obj["content"] = chunk
-        elif self.model == "together_ai":
+        elif (self.model == "together_ai") or ("togethercomputer" in self.model):
             chunk = next(self.completion_stream)
             text_data = self.handle_together_ai_chunk(chunk)
             if text_data == "":
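
Usage note (not part of the patch): after this change, the Together AI branches in get_optional_params and CustomStreamWrapper are selected whenever the model name contains "togethercomputer", so callers no longer need to pass custom_llm_provider="together_ai". A minimal sketch of the re-enabled streaming call, assuming litellm is installed and Together AI credentials are set in the environment:

# Minimal sketch, assuming litellm.completion and a Together AI API key
# are available; chunk structure may vary by litellm version.
from litellm import completion

user_message = "Write 1pg about YC & litellm"
messages = [{"content": user_message, "role": "user"}]

# No custom_llm_provider needed: "togethercomputer" in the model name now
# routes the request through the Together AI code paths, including streaming.
response = completion(
    model="togethercomputer/llama-2-70b-chat",
    messages=messages,
    stream=True,
    max_tokens=800,
)

for chunk in response:
    print(chunk)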