remove init for together_ai completion calls

ishaan-jaff 2023-09-04 09:59:24 -07:00
parent 46857577fa
commit f2b0fa90ab
2 changed files with 101 additions and 107 deletions

litellm/llms/together_ai.py

@@ -1,11 +1,11 @@
-import os, json
+import os
+import json
 from enum import Enum
 import requests
 import time
 from typing import Callable
 from litellm.utils import ModelResponse


 class TogetherAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -14,118 +14,110 @@ class TogetherAIError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs

+def validate_environment(api_key):
+    if api_key is None:
+        raise ValueError(
+            "Missing TogetherAI API Key - A call is being made to together_ai but no key is set either in the environment variables or via params"
+        )
+    headers = {
+        "accept": "application/json",
+        "content-type": "application/json",
+        "Authorization": "Bearer " + api_key,
+    }
+    return headers

-class TogetherAILLM:
-    def __init__(self, encoding, logging_obj, api_key=None):
-        self.encoding = encoding
-        self.completion_url = "https://api.together.xyz/inference"
-        self.api_key = api_key
-        self.logging_obj = logging_obj
-        self.validate_environment(api_key=api_key)
-
-    def validate_environment(
-        self, api_key
-    ):  # set up the environment required to run the model
-        # set the api key
-        if self.api_key == None:
-            raise ValueError(
-                "Missing TogetherAI API Key - A call is being made to together_ai but no key is set either in the environment variables or via params"
-            )
-        self.api_key = api_key
-        self.headers = {
-            "accept": "application/json",
-            "content-type": "application/json",
-            "Authorization": "Bearer " + self.api_key,
-        }
-
-    def completion(
-        self,
-        model: str,
-        messages: list,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        optional_params=None,
-        litellm_params=None,
-        logger_fn=None,
-    ):  # logic for parsing in - calling - parsing out model completion calls
-        model = model
-        prompt = ""
-        for message in messages:
-            if "role" in message:
-                if message["role"] == "user":
-                    prompt += f"{message['content']}"
-                else:
-                    prompt += f"{message['content']}"
-            else:
-                prompt += f"{message['content']}"
-        data = {
-            "model": model,
-            "prompt": prompt,
-            "request_type": "language-model-inference",
-            **optional_params,
-        }
-
-        ## LOGGING
-        self.logging_obj.pre_call(
-            input=prompt,
-            api_key=self.api_key,
-            additional_args={"complete_input_dict": data},
-        )
-        ## COMPLETION CALL
-        if (
-            "stream_tokens" in optional_params
-            and optional_params["stream_tokens"] == True
-        ):
-            response = requests.post(
-                self.completion_url,
-                headers=self.headers,
-                data=json.dumps(data),
-                stream=optional_params["stream_tokens"],
-            )
-            return response.iter_lines()
-        else:
-            response = requests.post(
-                self.completion_url,
-                headers=self.headers,
-                data=json.dumps(data)
-            )
-            ## LOGGING
-            self.logging_obj.post_call(
-                input=prompt,
-                api_key=self.api_key,
-                original_response=response.text,
-                additional_args={"complete_input_dict": data},
-            )
-            print_verbose(f"raw model_response: {response.text}")
-            ## RESPONSE OBJECT
-            completion_response = response.json()
-            if "error" in completion_response:
-                raise TogetherAIError(
-                    message=json.dumps(completion_response),
-                    status_code=response.status_code,
-                )
-            elif "error" in completion_response["output"]:
-                raise TogetherAIError(message=json.dumps(completion_response["output"]), status_code=response.status_code)
-            completion_response = completion_response["output"]["choices"][0]["text"]
-            ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-            prompt_tokens = len(self.encoding.encode(prompt))
-            completion_tokens = len(
-                self.encoding.encode(completion_response)
-            )
-            model_response["choices"][0]["message"]["content"] = completion_response
-            model_response["created"] = time.time()
-            model_response["model"] = model
-            model_response["usage"] = {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens,
-            }
-            return model_response
-
-    def embedding(
-        self,
-    ):  # logic for parsing in - calling - parsing out model embedding calls
-        pass
+def completion(
+    model: str,
+    messages: list,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+):
+    headers = validate_environment(api_key)
+    model = model
+    prompt = ""
+    for message in messages:
+        if "role" in message:
+            if message["role"] == "user":
+                prompt += f"{message['content']}"
+            else:
+                prompt += f"{message['content']}"
+        else:
+            prompt += f"{message['content']}"
+    data = {
+        "model": model,
+        "prompt": prompt,
+        "request_type": "language-model-inference",
+        **optional_params,
+    }
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+    ## COMPLETION CALL
+    if (
+        "stream_tokens" in optional_params
+        and optional_params["stream_tokens"] == True
+    ):
+        response = requests.post(
+            "https://api.together.xyz/inference",
+            headers=headers,
+            data=json.dumps(data),
+            stream=optional_params["stream_tokens"],
+        )
+        return response.iter_lines()
+    else:
+        response = requests.post(
+            "https://api.together.xyz/inference",
+            headers=headers,
+            data=json.dumps(data)
+        )
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        completion_response = response.json()
+        if "error" in completion_response:
+            raise TogetherAIError(
+                message=json.dumps(completion_response),
+                status_code=response.status_code,
+            )
+        elif "error" in completion_response["output"]:
+            raise TogetherAIError(
+                message=json.dumps(completion_response["output"]), status_code=response.status_code
+            )
+        completion_response = completion_response["output"]["choices"][0]["text"]

+        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        prompt_tokens = len(encoding.encode(prompt))
+        completion_tokens = len(
+            encoding.encode(completion_response)
+        )
+        model_response["choices"][0]["message"]["content"] = completion_response
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        return model_response
+
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass
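
With the class removed, callers invoke the module-level together_ai.completion() directly; the values the old TogetherAILLM constructor captured (encoding, api_key, logging_obj) now travel as keyword arguments. Below is a minimal sketch of such a direct call. The StubLogger class and the model id are illustrative assumptions for the sketch only; in practice litellm.completion() supplies the real logging object, encoding, and ModelResponse.

import os
import tiktoken
from litellm.utils import ModelResponse
from litellm.llms import together_ai

class StubLogger:
    # Illustrative stand-in for litellm's logging object; only the two
    # hooks that completion() actually invokes are needed.
    def pre_call(self, input, api_key, additional_args):
        print("sending:", additional_args["complete_input_dict"]["model"])

    def post_call(self, input, api_key, original_response, additional_args):
        print("received", len(original_response), "bytes")

response = together_ai.completion(
    model="togethercomputer/llama-2-7b-chat",  # illustrative model id
    messages=[{"role": "user", "content": "Hello, world"}],
    model_response=ModelResponse(),  # assumes a pre-populated first choice, as main.py relies on
    print_verbose=print,
    encoding=tiktoken.get_encoding("cl100k_base"),
    api_key=os.environ["TOGETHERAI_API_KEY"],
    logging_obj=StubLogger(),
    optional_params={},  # no stream_tokens, so the non-streaming branch runs
)
print(response["choices"][0]["message"]["content"])

One practical effect of the stateless design: nothing is cached between calls, so a rotated api_key or a different encoding can be passed per request without touching shared instance state.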

litellm/main.py

@@ -20,10 +20,10 @@ from litellm.utils import (
     completion_with_fallbacks,
 )
 from .llms import anthropic
+from .llms import together_ai
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
 from .llms.baseten import BasetenLLM
 from .llms.ai21 import AI21LLM
-from .llms.together_ai import TogetherAILLM
 from .llms.aleph_alpha import AlephAlphaLLM
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -578,9 +578,8 @@ def completion(
             or get_secret("TOGETHER_AI_TOKEN")
             or get_secret("TOGETHERAI_API_KEY")
         )
-        together_ai_client = TogetherAILLM(encoding=encoding, api_key=together_ai_key, logging_obj=logging)
-        model_response = together_ai_client.completion(
+        model_response = together_ai.completion(
             model=model,
             messages=messages,
             model_response=model_response,
@@ -588,6 +587,9 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
+            encoding=encoding,
+            api_key=together_ai_key,
+            logging_obj=logging
         )
         if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
             # don't try to access stream object,
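
Note the asymmetry the stream check in main.py guards against: when optional_params carries stream_tokens=True, together_ai.completion() returns response.iter_lines() (raw bytes from the HTTP response) rather than a populated ModelResponse, so the caller must not index into it like a completion object. A sketch of the streaming path, reusing the imports, StubLogger, and illustrative model id from the previous example:

chunks = together_ai.completion(
    model="togethercomputer/llama-2-7b-chat",  # illustrative model id
    messages=[{"role": "user", "content": "Hello, world"}],
    model_response=ModelResponse(),  # unused on this path; iter_lines() is returned instead
    print_verbose=print,
    encoding=tiktoken.get_encoding("cl100k_base"),
    api_key=os.environ["TOGETHERAI_API_KEY"],
    logging_obj=StubLogger(),
    optional_params={"stream_tokens": True},  # triggers the streaming branch
)
for line in chunks:  # raw server-sent event lines, not parsed chunks
    if line:
        print(line.decode())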