baseten client mapping

This commit is contained in:
ishaan-jaff 2023-09-04 15:41:36 -07:00
parent 3147bf1d99
commit db4f4c0191
2 changed files with 118 additions and 131 deletions

View file

@@ -26,7 +26,7 @@ from .llms import sagemaker
from .llms import bedrock
from .llms import huggingface_restapi
from .llms import aleph_alpha
from .llms.baseten import BasetenLLM
from .llms import baseten
import tiktoken
from concurrent.futures import ThreadPoolExecutor
@@ -751,10 +751,8 @@ def completion(
baseten_key = (
api_key or litellm.baseten_key or os.environ.get("BASETEN_API_KEY")
)
baseten_client = BasetenLLM(
encoding=encoding, api_key=baseten_key, logging_obj=logging
)
model_response = baseten_client.completion(
model_response = baseten.completion(
model=model,
messages=messages,
model_response=model_response,
@@ -762,6 +760,9 @@ def completion(
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=baseten_key,
logging_obj=logging
)
if inspect.isgenerator(model_response) or ("stream" in optional_params and optional_params["stream"] == True):
# don't try to access stream object,