diff --git a/litellm/__init__.py b/litellm/__init__.py
index e118273cf..e66e40dc6 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -191,6 +191,7 @@ provider_list = [
     "azure",
     "sagemaker",
     "bedrock",
+    "custom",  # custom API endpoints
 ]
 
 models_by_provider = {
diff --git a/litellm/main.py b/litellm/main.py
index 1637e4950..e60ace66b 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -546,7 +546,11 @@ def completion(
             }
             response = model_response
         elif (
-            model in litellm.huggingface_models or custom_llm_provider == "huggingface"
+            (
+                model in litellm.huggingface_models
+                and custom_llm_provider != "custom"  # a HF model pointed at a custom api_base is handled by the "custom" branch below
+            )
+            or custom_llm_provider == "huggingface"
         ):
             custom_llm_provider = "huggingface"
             huggingface_key = (
@@ -783,6 +787,63 @@
                 )
                 return response
             response = model_response
+        elif custom_llm_provider == "custom":
+            import requests
+
+            # note: the module-level litellm.api_base takes precedence over the api_base argument
+            url = litellm.api_base or api_base
+            if url is None:
+                raise ValueError("api_base is required when custom_llm_provider='custom'")
+            """
+            Assume requests to custom LLM API bases follow this format:
+            resp = requests.post(
+                api_base,
+                json={
+                    'model': 'meta-llama/Llama-2-13b-hf', # model name
+                    'params': {
+                        'prompt': ["The capital of France is P"],
+                        'max_tokens': 32,
+                        'temperature': 0.7,
+                        'top_p': 1.0,
+                        'top_k': 40,
+                    }
+                }
+            )
+            """
+            prompt = " ".join([message["content"] for message in messages])
+            resp = requests.post(
+                url,
+                json={
+                    'model': model,
+                    'params': {
+                        'prompt': [prompt],
+                        'max_tokens': max_tokens,
+                        'temperature': temperature,
+                        'top_p': top_p,
+                        'top_k': top_k,
+                    },
+                },
+            )
+            resp = resp.json()
+            """
+            Assume all responses from custom api_bases have this format:
+            {
+                'data': [
+                    {
+                        'prompt': 'The capital of France is P',
+                        'output': ['The capital of France is PARIS.\nThe capital of France is PARIS. ...'],
+                        'params': {'temperature': 0.7, 'top_k': 40, 'top_p': 1}
+                    }
+                ],
+                'message': 'ok'
+            }
+            """
+            string_response = resp['data'][0]['output'][0]
+            ## RESPONSE OBJECT
+            model_response["choices"][0]["message"]["content"] = string_response
+            model_response["created"] = time.time()
+            model_response["model"] = model
+            response = model_response
         else:
             raise ValueError(
                 f"Unable to map your input to a model. Check your input - {args}"
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index eb0ab5819..448716247 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -435,6 +435,22 @@ def test_completion_sagemaker():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
+def test_completion_custom_api_base():
+    try:
+        response = completion(
+            model="custom/meta-llama/Llama-2-13b-hf",
+            messages=messages,
+            temperature=0.2,
+            max_tokens=10,
+            api_base="https://api.autoai.dev/inference",
+        )
+        # Add any assertions here to check the response
+        print("got response\n", response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_completion_custom_api_base()
+
 # def test_vertex_ai():
 #     model_name = "chat-bison"
 #     try:
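
For anyone exercising the new provider path, below is a minimal sketch of a server that satisfies the request/response contract the "custom" branch assumes. It is an illustration only: the Flask framework, the /inference route, the port, and the canned output are assumptions, not part of this patch.

# mock_custom_api.py: hypothetical stand-in for a custom inference endpoint.
from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route("/inference", methods=["POST"])
def inference():
    body = request.get_json()
    params = body["params"]
    prompt = params["prompt"][0]
    # A real server would run the model named in body["model"]; this stub
    # returns a canned continuation so the response shape can be verified.
    return jsonify({
        "data": [
            {
                "prompt": prompt,
                "output": [prompt + "ARIS."],
                "params": {k: v for k, v in params.items() if k != "prompt"},
            }
        ],
        "message": "ok",
    })

if __name__ == "__main__":
    app.run(port=8080)

With this running locally, completion(model="custom/meta-llama/Llama-2-13b-hf", messages=messages, api_base="http://localhost:8080/inference") should return a model_response whose first choice carries output[0] from the JSON body.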
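
The new test_completion_custom_api_base depends on a live external endpoint (https://api.autoai.dev/inference). A hermetic variant could stub the HTTP call instead; the sketch below uses the third-party responses library, and the URL and canned payload are illustrative assumptions.

# Hypothetical hermetic variant of the new test, for litellm/tests/test_completion.py.
import responses

from litellm import completion

@responses.activate
def test_completion_custom_api_base_mocked():
    # Stub the custom endpoint with a payload in the assumed response format.
    responses.add(
        responses.POST,
        "https://example.com/inference",  # illustrative URL, not a real service
        json={
            "data": [
                {
                    "prompt": "The capital of France is P",
                    "output": ["The capital of France is PARIS."],
                    "params": {"temperature": 0.2, "top_k": 40, "top_p": 1},
                }
            ],
            "message": "ok",
        },
    )
    response = completion(
        model="custom/meta-llama/Llama-2-13b-hf",
        messages=[{"role": "user", "content": "The capital of France is P"}],
        temperature=0.2,
        max_tokens=10,
        api_base="https://example.com/inference",
    )
    assert response["choices"][0]["message"]["content"] == "The capital of France is PARIS."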