diff --git a/litellm/__init__.py b/litellm/__init__.py
index cc6f88c0c8..a4faaf2321 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -277,6 +277,7 @@ provider_list: List = [
     "nlp_cloud",
     "bedrock",
     "petals",
+    "oobabooga",
     "custom", # custom apis
 ]
 
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index 083f785552..2e536cdea6 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 3dd1f609b3..21214a8dec 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index b5c44d665d..bd50ebc532 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/oobabooga.py b/litellm/llms/oobabooga.py
new file mode 100644
index 0000000000..e49eba4228
--- /dev/null
+++ b/litellm/llms/oobabooga.py
@@ -0,0 +1,123 @@
+import os
+import json
+from enum import Enum
+import requests
+import time
+from typing import Callable, Optional
+from litellm.utils import ModelResponse
+from .prompt_templates.factory import prompt_factory, custom_prompt
+
+class OobaboogaError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+def validate_environment(api_key):
+    headers = {
+        "accept": "application/json",
+        "content-type": "application/json",
+    }
+    if api_key:
+        headers["Authorization"] = f"Token {api_key}"
+    return headers
+
+def completion(
+    model: str,
+    messages: list,
+    api_base: Optional[str],
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    custom_prompt_dict={},
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+    default_max_tokens_to_sample=None,
+):
+    headers = validate_environment(api_key)
+    if "https" in model:
+        completion_url = model
+    elif api_base:
+        completion_url = api_base
+    else:
+        raise OobaboogaError(status_code=404, message="API Base not set. Set one via completion(..,api_base='your-api-url')")
+    model = model
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details["roles"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
+            messages=messages
+        )
+    else:
+        prompt = prompt_factory(model=model, messages=messages)
+
+    completion_url = completion_url + "/api/v1/generate"
+    data = {
+        "prompt": prompt,
+        **optional_params,
+    }
+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+    ## COMPLETION CALL
+    response = requests.post(
+        completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
+    )
+    if "stream" in optional_params and optional_params["stream"] == True:
+        return response.iter_lines()
+    else:
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise OobaboogaError(message=response.text, status_code=response.status_code)
+        if "error" in completion_response:
+            raise OobaboogaError(
+                message=completion_response["error"],
+                status_code=response.status_code,
+            )
+        else:
+            try:
+                model_response["choices"][0]["message"]["content"] = completion_response['results'][0]['text']
+            except:
+                raise OobaboogaError(message=json.dumps(completion_response), status_code=response.status_code)
+
+        ## CALCULATING USAGE
+        prompt_tokens = len(
+            encoding.encode(prompt)
+        )
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"]["content"])
+        )
+
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        return model_response
+
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass
diff --git a/litellm/main.py b/litellm/main.py
index b9acd6e4af..ad25662528 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -35,6 +35,7 @@ from .llms import vllm
 from .llms import ollama
 from .llms import cohere
 from .llms import petals
+from .llms import oobabooga
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict
@@ -686,6 +687,28 @@ def completion(
             )
             return response
         response = model_response
+    elif custom_llm_provider == "oobabooga":
+        custom_llm_provider = "oobabooga"
+        model_response = oobabooga.completion(
+            model=model,
+            messages=messages,
+            model_response=model_response,
+            api_base=api_base, # type: ignore
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            api_key=None,
+            logger_fn=logger_fn,
+            encoding=encoding,
+            logging_obj=logging
+        )
+        if "stream" in optional_params and optional_params["stream"] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(
+                model_response, model, custom_llm_provider="oobabooga", logging_obj=logging
+            )
+            return response
+        response = model_response
     elif custom_llm_provider == "together_ai" or ("togethercomputer" in model) or (model in litellm.together_ai_models):
         custom_llm_provider = "together_ai"
         together_ai_key = (
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 3bfd3956db..fca0e550b0 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -47,6 +47,18 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 # test_completion_claude()
+
+# def test_completion_oobabooga():
+#     try:
+#         response = completion(
+#             model="oobabooga/vicuna-1.3b", messages=messages, api_base="http://127.0.0.1:5000"
+#         )
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+# test_completion_oobabooga()
 # aleph alpha
 # def test_completion_aleph_alpha():
 #     try:
diff --git a/pyproject.toml b/pyproject.toml
index d37c08138f..73f7a058ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.704"
+version = "0.1.705"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
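
Below is a minimal usage sketch of the provider this patch adds, not part of the diff itself. It assumes a local text-generation-webui (oobabooga) server listening on http://127.0.0.1:5000 and exposing /api/v1/generate; the model name, port, and prompt mirror the commented-out test above and are illustrative only.

```python
# Sketch only: requires a running oobabooga/text-generation-webui server at the
# api_base below. Model name "oobabooga/vicuna-1.3b" is taken from the
# commented-out test in this patch and is illustrative, not prescriptive.
from litellm import completion

messages = [{"role": "user", "content": "Hey, how's it going?"}]

# Non-streaming call. api_base is required: without it oobabooga.completion()
# raises OobaboogaError("API Base not set. ...").
response = completion(
    model="oobabooga/vicuna-1.3b",
    messages=messages,
    api_base="http://127.0.0.1:5000",
)
print(response["choices"][0]["message"]["content"])

# Streaming call. Per the main.py hunk, the raw response iterator is wrapped in
# a CustomStreamWrapper, so the result can be iterated chunk by chunk.
stream_response = completion(
    model="oobabooga/vicuna-1.3b",
    messages=messages,
    api_base="http://127.0.0.1:5000",
    stream=True,
)
for chunk in stream_response:
    print(chunk)
```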