fix tg computer

This commit is contained in:
ishaan-jaff 2023-08-17 18:01:12 -07:00
parent 068930f81a
commit 2ea36c49e3
2 changed files with 6 additions and 3 deletions

View file

@ -321,8 +321,10 @@ def test_petals():
# import asyncio # import asyncio
# def test_completion_together_ai_stream(): # def test_completion_together_ai_stream():
# user_message = "Write 1pg about YC & litellm"
# messages = [{ "content": user_message,"role": "user"}]
# try: # try:
# response = completion(model="togethercomputer/llama-2-70b-chat", messages=messages, custom_llm_provider="together_ai", stream=True, max_tokens=200) # response = completion(model="togethercomputer/llama-2-70b-chat", messages=messages, stream=True, max_tokens=800)
# print(response) # print(response)
# asyncio.run(get_response(response)) # asyncio.run(get_response(response))
# # print(string_response) # # print(string_response)
@ -335,4 +337,5 @@ def test_petals():
# print(elem) # print(elem)
# return # return
# test_completion_together_ai_stream()

View file

@ -392,7 +392,7 @@ def get_optional_params(
if stream: if stream:
optional_params["stream"] = stream optional_params["stream"] = stream
return optional_params return optional_params
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai" or ("togethercomputer" in model):
if stream: if stream:
optional_params["stream_tokens"] = stream optional_params["stream_tokens"] = stream
if temperature != 1: if temperature != 1:
@ -897,7 +897,7 @@ class CustomStreamWrapper:
elif self.model == "replicate": elif self.model == "replicate":
chunk = next(self.completion_stream) chunk = next(self.completion_stream)
completion_obj["content"] = chunk completion_obj["content"] = chunk
elif self.model == "together_ai": elif (self.model == "together_ai") or ("togethercomputer" in self.model):
chunk = next(self.completion_stream) chunk = next(self.completion_stream)
text_data = self.handle_together_ai_chunk(chunk) text_data = self.handle_together_ai_chunk(chunk)
if text_data == "": if text_data == "":