add streaming for together-ai

ishaan-jaff 2023-08-09 17:58:54 -07:00
parent 23fc936b65
commit 3c3d144584
3 changed files with 36 additions and 3 deletions

View file

@@ -375,9 +375,13 @@ def completion(
             "model": model,
             "prompt": prompt,
             "request_type": "language-model-inference",
+            **optional_params
           },
           headers=headers
         )
+        if stream == True:
+            response = CustomStreamWrapper(res, "together_ai")
+            return response
         completion_response = res.json()['output']['choices'][0]['text']
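
For context, a minimal sketch of what the patched request path amounts to. This is not the library's actual code; the endpoint URL, env var name, and function name are illustrative assumptions:

    import os
    import requests

    def together_completion_sketch(prompt, model, stream=False):
        # Hypothetical sketch of the patched flow; endpoint URL and env var
        # name are assumptions for illustration.
        headers = {"Authorization": f"Bearer {os.environ['TOGETHERAI_API_KEY']}"}
        res = requests.post(
            "https://api.together.xyz/inference",
            json={
                "model": model,
                "prompt": prompt,
                "request_type": "language-model-inference",
                "stream_tokens": stream,  # Together AI's name for OpenAI's `stream`
            },
            headers=headers,
            stream=stream,  # keep the connection open so chunks can be iterated
        )
        if stream:
            # the commit wraps the raw response: CustomStreamWrapper(res, "together_ai")
            return res
        return res.json()["output"]["choices"][0]["text"]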

View file

@@ -200,10 +200,21 @@ def test_completion_replicate_stability():
 ######## Test TogetherAI ########
 def test_completion_together_ai():
-    model_name = "togethercomputer/mpt-30b-chat"
+    model_name = "togethercomputer/llama-2-70b-chat"
     try:
         response = completion(model=model_name, messages=messages, together_ai=True)
         # Add any assertions here to check the response
         print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_completion_together_ai_stream():
+    model_name = "togethercomputer/llama-2-70b-chat"
+    try:
+        response = completion(model=model_name, messages=messages, together_ai=True, stream=True)
+        # Add any assertions here to check the response
+        print(response)
+        for chunk in response:
+            print(chunk['choices'][0]['delta']) # same as openai format
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
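
Each chunk the wrapper yields follows OpenAI's streaming shape (see the `__next__` changes below), so consumer code written against OpenAI's API keeps working. A chunk printed by the new test would look roughly like this (the text itself is made up):

    {"choices": [{"delta": {"role": "assistant", "content": " Hello"}}]}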

View file

@@ -247,7 +247,7 @@ def get_optional_params(
         return optional_params
     elif together_ai == True:
         if stream:
-            optional_params["stream"] = stream
+            optional_params["stream_tokens"] = stream
         if temperature != 1:
             optional_params["temperature"] = temperature
         if top_p != 1:
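
Together AI's inference API names the flag `stream_tokens` rather than OpenAI's `stream`, so the translation happens in `get_optional_params`. Illustratively, with `stream=True` and default sampling parameters, the dict merged into the request body via `**optional_params` above is just:

    optional_params = {"stream_tokens": True}  # forwarded to Together AI as-is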
@@ -652,12 +652,25 @@ class CustomStreamWrapper:
         if model in litellm.cohere_models:
             # cohere does not return an iterator, so we need to wrap it in one
             self.completion_stream = iter(completion_stream)
+        elif model == "together_ai":
+            self.completion_stream = iter(completion_stream)
         else:
             self.completion_stream = completion_stream
 
     def __iter__(self):
         return self
 
+    def handle_together_ai_chunk(self, chunk):
+        chunk = chunk.decode("utf-8")
+        text_index = chunk.find('"text":"')  # this checks if text: exists
+        text_start = text_index + len('"text":"')
+        text_end = chunk.find('"}', text_start)
+        if text_index != -1 and text_end != -1:
+            extracted_text = chunk[text_start:text_end]
+            return extracted_text
+        else:
+            return ""
+
     def __next__(self):
         completion_obj = {"role": "assistant", "content": ""}
         if self.model in litellm.anthropic_models:
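
`handle_together_ai_chunk` scans each raw byte chunk for a `"text":"` marker instead of parsing JSON, which keeps the hot path simple but is brittle if the streamed text itself contains `"}`. A hypothetical, more defensive variant; the wire format assumed here (SSE `data: {...}` lines carrying a `choices[0].text` field) is an assumption for illustration:

    import json

    def handle_together_ai_chunk_json(chunk: bytes) -> str:
        # Hypothetical alternative: parse the streamed line as JSON rather
        # than substring-scanning; the wire format is an assumption.
        line = chunk.decode("utf-8").strip()
        if not line.startswith("data:"):
            return ""
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":  # assumed end-of-stream sentinel
            return ""
        try:
            return json.loads(payload)["choices"][0]["text"]
        except (json.JSONDecodeError, KeyError, IndexError):
            return ""  # skip malformed or empty chunks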
@@ -666,9 +679,14 @@ class CustomStreamWrapper:
         elif self.model == "replicate":
             chunk = next(self.completion_stream)
             completion_obj["content"] = chunk
+        elif self.model == "together_ai":
+            chunk = next(self.completion_stream)
+            text_data = self.handle_together_ai_chunk(chunk)
+            if text_data == "":
+                return self.__next__()
+            completion_obj["content"] = text_data
         elif self.model in litellm.cohere_models:
             chunk = next(self.completion_stream)
             completion_obj["content"] = chunk.text
         # return this for all models
         return {"choices": [{"delta": completion_obj}]}