remove init for together_ai completion calls

ishaan-jaff 2023-09-04 09:59:24 -07:00
parent 46857577fa
commit f2b0fa90ab
2 changed files with 101 additions and 107 deletions
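At a glance: the `TogetherAILLM` class is collapsed into plain module-level functions, so callers pass per-call state (`encoding`, `api_key`, `logging_obj`) as arguments instead of constructing a client first. A schematic of the call-site change, using the names from the diff below:

```python
# Before: per-call state lived on a client object.
#   together_ai_client = TogetherAILLM(encoding=encoding, api_key=together_ai_key, logging_obj=logging)
#   model_response = together_ai_client.completion(model=model, messages=messages, ...)
#
# After: a stateless module function; the same state arrives as keyword arguments.
#   model_response = together_ai.completion(
#       model=model, messages=messages, ...,
#       encoding=encoding, api_key=together_ai_key, logging_obj=logging,
#   )
```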

litellm/llms/together_ai.py

```diff
@@ -1,11 +1,11 @@
-import os, json
+import os
+import json
 from enum import Enum
 import requests
 import time
 from typing import Callable
 from litellm.utils import ModelResponse


 class TogetherAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -14,40 +14,31 @@ class TogetherAIError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs


-class TogetherAILLM:
-    def __init__(self, encoding, logging_obj, api_key=None):
-        self.encoding = encoding
-        self.completion_url = "https://api.together.xyz/inference"
-        self.api_key = api_key
-        self.logging_obj = logging_obj
-        self.validate_environment(api_key=api_key)
-
-    def validate_environment(
-        self, api_key
-    ):  # set up the environment required to run the model
-        # set the api key
-        if self.api_key == None:
-            raise ValueError(
-                "Missing TogetherAI API Key - A call is being made to together_ai but no key is set either in the environment variables or via params"
-            )
-        self.api_key = api_key
-        self.headers = {
-            "accept": "application/json",
-            "content-type": "application/json",
-            "Authorization": "Bearer " + self.api_key,
-        }
+def validate_environment(api_key):
+    if api_key is None:
+        raise ValueError(
+            "Missing TogetherAI API Key - A call is being made to together_ai but no key is set either in the environment variables or via params"
+        )
+    headers = {
+        "accept": "application/json",
+        "content-type": "application/json",
+        "Authorization": "Bearer " + api_key,
+    }
+    return headers

-    def completion(
-        self,
+def completion(
     model: str,
     messages: list,
     model_response: ModelResponse,
     print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
-):  # logic for parsing in - calling - parsing out model completion calls
+):
+    headers = validate_environment(api_key)
     model = model
     prompt = ""
     for message in messages:
```
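The refactored `validate_environment` is now a pure helper: it raises on a missing key and returns the headers dict instead of stashing it on `self`. A minimal sketch of exercising it, assuming litellm is installed at this revision (the key is a placeholder):

```python
from litellm.llms.together_ai import validate_environment

# Returns the headers rather than mutating client state.
headers = validate_environment(api_key="tok-xxxx")  # placeholder key
assert headers["Authorization"] == "Bearer tok-xxxx"

# A missing key still fails fast with the same ValueError as before.
try:
    validate_environment(api_key=None)
except ValueError as err:
    print(err)
```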
```diff
@@ -66,9 +57,9 @@ class TogetherAILLM:
     }

     ## LOGGING
-        self.logging_obj.pre_call(
+    logging_obj.pre_call(
         input=prompt,
-        api_key=self.api_key,
+        api_key=api_key,
         additional_args={"complete_input_dict": data},
     )
     ## COMPLETION CALL
@@ -77,22 +68,22 @@ class TogetherAILLM:
         and optional_params["stream_tokens"] == True
     ):
         response = requests.post(
-            self.completion_url,
-            headers=self.headers,
+            "https://api.together.xyz/inference",
+            headers=headers,
             data=json.dumps(data),
             stream=optional_params["stream_tokens"],
         )
         return response.iter_lines()
     else:
         response = requests.post(
-            self.completion_url,
-            headers=self.headers,
+            "https://api.together.xyz/inference",
+            headers=headers,
             data=json.dumps(data)
         )
         ## LOGGING
-        self.logging_obj.post_call(
+        logging_obj.post_call(
             input=prompt,
-            api_key=self.api_key,
+            api_key=api_key,
             original_response=response.text,
             additional_args={"complete_input_dict": data},
         )
```
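Note that with `stream_tokens` set, `completion` returns requests' raw line iterator and exits before ever touching `model_response` or `encoding`, so the caller owns all decoding. A sketch of a consumer under stated assumptions: the endpoint is assumed to emit SSE-style `data: {...}` lines, the model id and key are placeholders, and `_NoOpLogger` is a hypothetical stub standing in for litellm's logging object (only `pre_call`/`post_call` are needed by this module):

```python
import json
import tiktoken
from litellm.llms import together_ai

class _NoOpLogger:
    # Assumed stand-in for litellm's logging object; this module only
    # calls pre_call/post_call on it.
    def pre_call(self, *args, **kwargs):
        pass
    def post_call(self, *args, **kwargs):
        pass

stream = together_ai.completion(
    model="togethercomputer/llama-2-7b-chat",  # example model id
    messages=[{"role": "user", "content": "Say hi"}],
    model_response=None,  # unused on the streaming path
    print_verbose=print,
    encoding=tiktoken.get_encoding("cl100k_base"),  # stand-in encoding
    api_key="tok-xxxx",  # placeholder key
    logging_obj=_NoOpLogger(),
    optional_params={"stream_tokens": True},
)
for raw_line in stream:
    if not raw_line:
        continue
    text = raw_line.decode("utf-8")
    if text.startswith("data:"):  # assumed wire format
        print(json.loads(text[len("data:"):].strip()))
```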
```diff
@@ -106,14 +97,16 @@ class TogetherAILLM:
                 status_code=response.status_code,
             )
         elif "error" in completion_response["output"]:
-            raise TogetherAIError(message=json.dumps(completion_response["output"]), status_code=response.status_code)
+            raise TogetherAIError(
+                message=json.dumps(completion_response["output"]), status_code=response.status_code
+            )
         completion_response = completion_response["output"]["choices"][0]["text"]

     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-    prompt_tokens = len(self.encoding.encode(prompt))
+    prompt_tokens = len(encoding.encode(prompt))
     completion_tokens = len(
-        self.encoding.encode(completion_response)
+        encoding.encode(completion_response)
     )
     model_response["choices"][0]["message"]["content"] = completion_response
     model_response["created"] = time.time()
```
```diff
@@ -125,7 +118,6 @@ class TogetherAILLM:
         }
         return model_response

-    def embedding(
-        self,
-    ):  # logic for parsing in - calling - parsing out model embedding calls
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
     pass
```

litellm/main.py

```diff
@@ -20,10 +20,10 @@ from litellm.utils import (
     completion_with_fallbacks,
 )
 from .llms import anthropic
+from .llms import together_ai
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
 from .llms.baseten import BasetenLLM
 from .llms.ai21 import AI21LLM
-from .llms.together_ai import TogetherAILLM
 from .llms.aleph_alpha import AlephAlphaLLM
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -579,8 +579,7 @@ def completion(
             or get_secret("TOGETHERAI_API_KEY")
         )
-        together_ai_client = TogetherAILLM(encoding=encoding, api_key=together_ai_key, logging_obj=logging)
-        model_response = together_ai_client.completion(
+        model_response = together_ai.completion(
             model=model,
             messages=messages,
             model_response=model_response,
@@ -588,6 +587,9 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
+            encoding=encoding,
+            api_key=together_ai_key,
+            logging_obj=logging
         )
         if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
             # don't try to access stream object,
```
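End to end, `main.py` now resolves the key (explicit param, or the `TOGETHERAI_API_KEY` env var/secret) and forwards it, the encoding, and the logging object straight into `together_ai.completion`. A hedged sketch of the full path; the model id is an example, the key is a placeholder, and passing `custom_llm_provider` to pin the route is an assumption about the routing kwargs of this era:

```python
import os
import litellm

os.environ["TOGETHERAI_API_KEY"] = "tok-xxxx"  # placeholder key
response = litellm.completion(
    model="togethercomputer/llama-2-70b-chat",  # example Together AI model id
    messages=[{"role": "user", "content": "Hello from litellm"}],
    custom_llm_provider="together_ai",  # assumed explicit routing kwarg
)
print(response["choices"][0]["message"]["content"])
```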