add oobabooga text web api support

Krrish Dholakia 2023-09-19 18:56:53 -07:00
parent 31fc90b239
commit ecad921b01
8 changed files with 160 additions and 1 deletion


@@ -277,6 +277,7 @@ provider_list: List = [
    "nlp_cloud",
    "bedrock",
    "petals",
    "oobabooga",
    "custom",  # custom apis
]
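With "oobabooga" registered in the provider list, requests can be routed to a locally running text-generation-webui server through the usual completion() entry point. A minimal sketch, mirroring the commented-out test added later in this commit; the model name and local api_base URL are placeholders for whatever your server is actually hosting:

import litellm

# placeholder model and local server URL, taken from the test in this commit
response = litellm.completion(
    model="oobabooga/vicuna-1.3b",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    api_base="http://127.0.0.1:5000",
)
print(response)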

litellm/llms/oobabooga.py Normal file

@@ -0,0 +1,123 @@
import os
import json
import time
import requests
from enum import Enum
from typing import Callable, Optional
from litellm.utils import ModelResponse
from .prompt_templates.factory import prompt_factory, custom_prompt


class OobaboogaError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


def validate_environment(api_key):
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
    }
    if api_key:
        headers["Authorization"] = f"Token {api_key}"
    return headers


def completion(
    model: str,
    messages: list,
    api_base: Optional[str],
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    custom_prompt_dict={},
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
    default_max_tokens_to_sample=None,
):
    headers = validate_environment(api_key)
    if "https" in model:
        completion_url = model
    elif api_base:
        completion_url = api_base
    else:
        raise OobaboogaError(
            status_code=404,
            message="API Base not set. Set one via completion(..,api_base='your-api-url')",
        )
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_details = custom_prompt_dict[model]
        prompt = custom_prompt(
            role_dict=model_prompt_details["roles"],
            initial_prompt_value=model_prompt_details["initial_prompt_value"],
            final_prompt_value=model_prompt_details["final_prompt_value"],
            messages=messages,
        )
    else:
        prompt = prompt_factory(model=model, messages=messages)

    completion_url = completion_url + "/api/v1/generate"
    data = {
        "prompt": prompt,
        **optional_params,
    }

    ## LOGGING
    logging_obj.pre_call(
        input=prompt,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
    )

    ## COMPLETION CALL
    response = requests.post(
        completion_url,
        headers=headers,
        data=json.dumps(data),
        stream=optional_params["stream"] if "stream" in optional_params else False,
    )
    if "stream" in optional_params and optional_params["stream"] == True:
        return response.iter_lines()
    else:
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response.text}")

        ## RESPONSE OBJECT
        try:
            completion_response = response.json()
        except Exception:
            raise OobaboogaError(message=response.text, status_code=response.status_code)
        if "error" in completion_response:
            raise OobaboogaError(
                message=completion_response["error"],
                status_code=response.status_code,
            )
        else:
            try:
                model_response["choices"][0]["message"]["content"] = completion_response["results"][0]["text"]
            except Exception:
                raise OobaboogaError(
                    message=json.dumps(completion_response),
                    status_code=response.status_code,
                )

        ## CALCULATING USAGE
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"]["content"])
        )

        model_response["created"] = time.time()
        model_response["model"] = model
        model_response["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        return model_response


def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass
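For reference, the handler above boils down to a single POST against the server's legacy generate endpoint. A rough sketch of the equivalent raw request, assuming a text-generation-webui instance at http://127.0.0.1:5000; the max_new_tokens field is an assumed server-side option passed through via optional_params, not something this module sets itself:

import json
import requests

# hypothetical local endpoint; the /api/v1/generate path comes from the handler above
url = "http://127.0.0.1:5000/api/v1/generate"
payload = {
    "prompt": "You are a helpful assistant.\nUser: Hey, how's it going?\nAssistant:",
    "max_new_tokens": 100,  # assumed optional parameter for the server
}
resp = requests.post(
    url,
    headers={"accept": "application/json", "content-type": "application/json"},
    data=json.dumps(payload),
)
# the handler reads the generated text from results[0]["text"]
print(resp.json()["results"][0]["text"])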


@@ -35,6 +35,7 @@ from .llms import vllm
from .llms import ollama
from .llms import cohere
from .llms import petals
from .llms import oobabooga
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict
@@ -686,6 +687,28 @@ def completion(
            )
            return response
        response = model_response
    elif custom_llm_provider == "oobabooga":
        custom_llm_provider = "oobabooga"
        model_response = oobabooga.completion(
            model=model,
            messages=messages,
            model_response=model_response,
            api_base=api_base,  # type: ignore
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            api_key=None,
            logger_fn=logger_fn,
            encoding=encoding,
            logging_obj=logging,
        )
        if "stream" in optional_params and optional_params["stream"] == True:
            # don't try to access stream object,
            response = CustomStreamWrapper(
                model_response, model, custom_llm_provider="oobabooga", logging_obj=logging
            )
            return response
        response = model_response
    elif custom_llm_provider == "together_ai" or ("togethercomputer" in model) or (model in litellm.together_ai_models):
        custom_llm_provider = "together_ai"
        together_ai_key = (
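When stream=True, the new branch wraps the raw line iterator in CustomStreamWrapper, so callers consume it the same way as other litellm providers. A minimal sketch, assuming the same placeholder model and local server as the test below:

import litellm

# placeholder model and local server URL; streamed chunks are yielded by CustomStreamWrapper
response = litellm.completion(
    model="oobabooga/vicuna-1.3b",
    messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
    api_base="http://127.0.0.1:5000",
    stream=True,
)
for chunk in response:
    print(chunk)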


@@ -47,6 +47,18 @@ def test_completion_claude():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_completion_claude()
# def test_completion_oobabooga():
#     try:
#         response = completion(
#             model="oobabooga/vicuna-1.3b", messages=messages, api_base="http://127.0.0.1:5000"
#         )
#         # Add any assertions here to check the response
#         print(response)
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")
# test_completion_oobabooga()
# aleph alpha
# def test_completion_aleph_alpha():
#     try:


@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "0.1.704"
+version = "0.1.705"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"