diff --git a/litellm/main.py b/litellm/main.py
index 81dbc76b5..56613c8c8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -375,9 +375,13 @@ def completion(
           "model": model,
           "prompt": prompt,
           "request_type": "language-model-inference",
+          **optional_params
         },
         headers=headers
       )
+      if stream == True:
+        response = CustomStreamWrapper(res, "together_ai")
+        return response
       completion_response = res.json()['output']['choices'][0]['text']
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 9aa475d4c..9583ea03f 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -200,10 +200,21 @@ def test_completion_replicate_stability():
 ######## Test TogetherAI ########
 def test_completion_together_ai():
-    model_name = "togethercomputer/mpt-30b-chat"
+    model_name = "togethercomputer/llama-2-70b-chat"
     try:
         response = completion(model=model_name, messages=messages, together_ai=True)
         # Add any assertions here to check the response
         print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_completion_together_ai_stream():
+    model_name = "togethercomputer/llama-2-70b-chat"
+    try:
+        response = completion(model=model_name, messages=messages, together_ai=True, stream=True)
+        # Add any assertions here to check the response
+        print(response)
+        for chunk in response:
+            print(chunk['choices'][0]['delta']) # same as openai format
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 8369b0202..22fb381b7 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -247,7 +247,7 @@ def get_optional_params(
         return optional_params
     elif together_ai == True:
         if stream:
-            optional_params["stream"] = stream
+            optional_params["stream_tokens"] = stream
         if temperature != 1:
             optional_params["temperature"] = temperature
         if top_p != 1:
@@ -652,12 +652,25 @@ class CustomStreamWrapper:
         if model in litellm.cohere_models:
             # cohere does not return an iterator, so we need to wrap it in one
             self.completion_stream = iter(completion_stream)
+        elif model == "together_ai":
+            self.completion_stream = iter(completion_stream)
         else:
             self.completion_stream = completion_stream
 
     def __iter__(self):
         return self
 
+    def handle_together_ai_chunk(self, chunk):
+        chunk = chunk.decode("utf-8")
+        text_index = chunk.find('"text":"') # this checks if text: exists
+        text_start = text_index + len('"text":"')
+        text_end = chunk.find('"}', text_start)
+        if text_index != -1 and text_end != -1:
+            extracted_text = chunk[text_start:text_end]
+            return extracted_text
+        else:
+            return ""
+
     def __next__(self):
         completion_obj ={ "role": "assistant", "content": ""}
         if self.model in litellm.anthropic_models:
@@ -666,9 +679,14 @@ class CustomStreamWrapper:
         elif self.model == "replicate":
             chunk = next(self.completion_stream)
             completion_obj["content"] = chunk
+        elif self.model == "together_ai":
+            chunk = next(self.completion_stream)
+            text_data = self.handle_together_ai_chunk(chunk)
+            if text_data == "":
+                return self.__next__()
+            completion_obj["content"] = text_data
         elif self.model in litellm.cohere_models:
             chunk = next(self.completion_stream)
             completion_obj["content"] = chunk.text
         # return this for all models
         return {"choices": [{"delta": completion_obj}]}
-