use api_base instead of custom_api_base

ishaan-jaff 2023-09-02 17:11:30 -07:00
parent e6836985c8
commit 09ae510a58
9 changed files with 39 additions and 39 deletions
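
For reference, the renamed keyword in a typical completion call, as a minimal sketch; it mirrors the Ollama example in the docs updated below, and the local endpoint is illustrative:

```python
# Minimal sketch of the rename: `api_base` replaces `custom_api_base`.
# Model, endpoint, and provider mirror the Ollama example in the updated docs.
from litellm import completion

messages = [{"role": "user", "content": "Hey, how's it going?"}]

response = completion(
    model="llama2",
    messages=messages,
    api_base="http://localhost:11434",  # previously custom_api_base="http://localhost:11434"
    custom_llm_provider="ollama",
)
print(response)
```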

View file

@@ -88,7 +88,7 @@
 }
 ],
 "source": [
-"response = completion(model=\"llama2\", messages=messages, custom_api_base=\"http://localhost:11434\", custom_llm_provider=\"ollama\", stream=True)\n",
+"response = completion(model=\"llama2\", messages=messages, api_base=\"http://localhost:11434\", custom_llm_provider=\"ollama\", stream=True)\n",
 "print(response)"
 ]
 },

View file

@@ -178,12 +178,12 @@ Ollama supported models: https://github.com/jmorganca/ollama
 | Model Name | Function Call | Required OS Variables |
 |----------------------|-----------------------------------------------------------------------------------|--------------------------------|
-| Llama2 7B | `completion(model='llama2', messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
-| Llama2 13B | `completion(model='llama2:13b', messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
-| Llama2 70B | `completion(model='llama2:70b', messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
-| Llama2 Uncensored | `completion(model='llama2-uncensored', messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
-| Orca Mini | `completion(model='orca-mini', messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
-| Vicuna | `completion(model='vicuna', messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
-| Nous-Hermes | `completion(model='nous-hermes', messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
-| Nous-Hermes 13B | `completion(model='nous-hermes:13b', messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
-| Wizard Vicuna Uncensored | `completion(model='wizard-vicuna', messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
+| Llama2 7B | `completion(model='llama2', messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
+| Llama2 13B | `completion(model='llama2:13b', messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
+| Llama2 70B | `completion(model='llama2:70b', messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
+| Llama2 Uncensored | `completion(model='llama2-uncensored', messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
+| Orca Mini | `completion(model='orca-mini', messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
+| Vicuna | `completion(model='vicuna', messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
+| Nous-Hermes | `completion(model='nous-hermes', messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
+| Nous-Hermes 13B | `completion(model='nous-hermes:13b', messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |
+| Wizard Vicuna Uncensored | `completion(model='wizard-vicuna', messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)` | No API Key required |

View file

@@ -41,14 +41,14 @@ from litellm import completion
 model = "meta-llama/Llama-2-7b-hf"
 messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format
-custom_api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud"
+api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud"
 ### CALLING ENDPOINT
-completion(model=model, messages=messages, custom_llm_provider="huggingface", custom_api_base=custom_api_base)
+completion(model=model, messages=messages, custom_llm_provider="huggingface", api_base=api_base)
 ```
 What's happening?
-- custom_api_base: Optional param. Since this uses a deployed endpoint (not the [default huggingface inference endpoint](https://github.com/BerriAI/litellm/blob/6aff47083be659b80e00cb81eb783cb24db2e183/litellm/llms/huggingface_restapi.py#L35)), we pass that to LiteLLM.
+- api_base: Optional param. Since this uses a deployed endpoint (not the [default huggingface inference endpoint](https://github.com/BerriAI/litellm/blob/6aff47083be659b80e00cb81eb783cb24db2e183/litellm/llms/huggingface_restapi.py#L35)), we pass that to LiteLLM.
 ### Case 3: Call Llama2 private Huggingface endpoint
@@ -72,10 +72,10 @@ os.environ["HF_TOKEN] = "..."
 model = "meta-llama/Llama-2-7b-hf"
 messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format
-custom_api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud"
+api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud"
 ### CALLING ENDPOINT
-completion(model=model, messages=messages, custom_llm_provider="huggingface", custom_api_base=custom_api_base)
+completion(model=model, messages=messages, custom_llm_provider="huggingface", api_base=api_base)
 ```
 **Setting it as package variable**
@@ -93,10 +93,10 @@ litellm.huggingface_key = "..."
 model = "meta-llama/Llama-2-7b-hf"
 messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format
-custom_api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud"
+api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud"
 ### CALLING ENDPOINT
-completion(model=model, messages=messages, custom_llm_provider="huggingface", custom_api_base=custom_api_base)
+completion(model=model, messages=messages, custom_llm_provider="huggingface", api_base=api_base)
 ```
 **Passed in during completion call**
@@ -111,8 +111,8 @@ from litellm import completion
 model = "meta-llama/Llama-2-7b-hf"
 messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format
-custom_api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud"
+api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud"
 ### CALLING ENDPOINT
-completion(model=model, messages=messages, custom_llm_provider="huggingface", custom_api_base=custom_api_base, api_key="...")
+completion(model=model, messages=messages, custom_llm_provider="huggingface", api_base=api_base, api_key="...")
 ```

View file

@@ -38,7 +38,7 @@ class HuggingfaceRestAPILLM:
         self,
         model: str,
         messages: list,
-        custom_api_base: str,
+        api_base: str,
         model_response: ModelResponse,
         print_verbose: Callable,
         optional_params=None,
@@ -48,8 +48,8 @@
         completion_url: str = ""
         if "https" in model:
             completion_url = model
-        elif custom_api_base:
-            completion_url = custom_api_base
+        elif api_base:
+            completion_url = api_base
         elif "HF_API_BASE" in os.environ:
             completion_url = os.getenv("HF_API_BASE", "")
         else:
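
For context, the endpoint-precedence this handler applies, restated as a standalone sketch; names mirror the hunk above, but the body of the final else branch isn't shown in the diff, so the fallback used here is an assumption:

```python
import os

def resolve_completion_url(model: str, api_base: str) -> str:
    # Precedence mirrored from the hunk above:
    # 1. a full URL passed as the model string
    # 2. the (renamed) api_base argument
    # 3. the HF_API_BASE environment variable
    if "https" in model:
        return model
    elif api_base:
        return api_base
    elif "HF_API_BASE" in os.environ:
        return os.getenv("HF_API_BASE", "")
    # Assumption: the real code falls back to the default Hugging Face
    # inference endpoint for the given model in its else branch.
    return f"https://api-inference.huggingface.co/models/{model}"
```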

View file

@@ -92,7 +92,7 @@ def completion(
    verbose=False,
    azure=False,
    custom_llm_provider=None,
-   custom_api_base=None,
+   api_base=None,
    litellm_call_id=None,
    litellm_logging_obj=None,
    use_client=False,
@@ -153,7 +153,7 @@ def completion(
        logger_fn=logger_fn,
        verbose=verbose,
        custom_llm_provider=custom_llm_provider,
-       custom_api_base=custom_api_base,
+       api_base=api_base,
        litellm_call_id=litellm_call_id,
        model_alias_map=litellm.model_alias_map,
        completion_call_id=id
@@ -223,7 +223,7 @@ def completion(
        # note: if a user sets a custom base - we should ensure this works
        # allow for the setting of dynamic and stateful api-bases
        api_base = (
-           custom_api_base
+           api_base
            or litellm.api_base
            or get_secret("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
@@ -567,7 +567,7 @@ def completion(
        model_response = huggingface_client.completion(
            model=model,
            messages=messages,
-           custom_api_base=custom_api_base,
+           api_base=api_base,
            model_response=model_response,
            print_verbose=print_verbose,
            optional_params=optional_params,
@@ -692,7 +692,7 @@ def completion(
            response = model_response
        elif custom_llm_provider == "ollama":
            endpoint = (
-               litellm.api_base if litellm.api_base is not None else custom_api_base
+               litellm.api_base if litellm.api_base is not None else api_base
            )
            prompt = " ".join([message["content"] for message in messages])
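
The OpenAI api_base fallback chain touched in the hunk at -223 above, restated as a sketch; `litellm.api_base` and `get_secret` are stood in with a plain argument and an environment lookup here:

```python
import os

def pick_openai_api_base(api_base=None, package_api_base=None):
    # Order mirrored from the diff: per-call kwarg, then the package-level
    # litellm.api_base, then the OPENAI_API_BASE secret/env var, then the
    # public OpenAI default.
    return (
        api_base
        or package_api_base
        or os.getenv("OPENAI_API_BASE")
        or "https://api.openai.com/v1"
    )
```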

View file

@@ -31,9 +31,9 @@ def testing_batch_completion(*args, **kwargs):
            if isinstance(model, dict) and "custom_llm_provider" in model
            else None
        )
-       kwargs_modified["custom_api_base"] = (
-           model["custom_api_base"]
-           if isinstance(model, dict) and "custom_api_base" in model
+       kwargs_modified["api_base"] = (
+           model["api_base"]
+           if isinstance(model, dict) and "api_base" in model
            else None
        )
        for message_list in batch_messages:
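
A sketch of the per-model entry this batch helper appears to expect after the rename; only the `custom_llm_provider` and `api_base` key names come from the hunk above, everything else is illustrative:

```python
# Plain strings and dicts can be mixed; dicts may carry a per-model
# custom_llm_provider and (now) api_base.
models = [
    "gpt-3.5-turbo",
    {
        "model": "llama2",                     # assumed key for the model name
        "custom_llm_provider": "ollama",
        "api_base": "http://localhost:11434",  # was custom_api_base
    },
]
```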

View file

@@ -20,14 +20,14 @@ models = ["gorilla-7b-hf-v1", "gpt-4"]
 custom_llm_provider = None
 messages = [{"role": "user", "content": "Hey, how's it going?"}]
 for model in models: # iterate through list
-    custom_api_base = None
+    api_base = None
     if model == "gorilla-7b-hf-v1":
         custom_llm_provider = "custom_openai"
-        custom_api_base = "http://zanino.millennium.berkeley.edu:8000/v1"
+        api_base = "http://zanino.millennium.berkeley.edu:8000/v1"
     completion(
         model=model,
         messages=messages,
         custom_llm_provider=custom_llm_provider,
-        custom_api_base=custom_api_base,
+        api_base=api_base,
         logger_fn=logging_fn,
     )

View file

@@ -24,7 +24,7 @@
 # def test_completion_ollama():
 #     try:
-#         response = completion(model="llama2", messages=messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama")
+#         response = completion(model="llama2", messages=messages, api_base="http://localhost:11434", custom_llm_provider="ollama")
 #         print(response)
 #         string_response = asyncio.run(get_response(response))
 #         print(string_response)
@@ -36,7 +36,7 @@
 # def test_completion_ollama_stream():
 #     try:
-#         response = completion(model="llama2", messages=messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)
+#         response = completion(model="llama2", messages=messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)
 #         print(response)
 #         string_response = asyncio.run(get_response(response))
 #         print(string_response)

View file

@@ -658,7 +658,7 @@ def get_litellm_params(
    replicate=False,
    together_ai=False,
    custom_llm_provider=None,
-   custom_api_base=None,
+   api_base=None,
    litellm_call_id=None,
    model_alias_map=None,
    completion_call_id=None
@@ -670,7 +670,7 @@
        "logger_fn": logger_fn,
        "verbose": verbose,
        "custom_llm_provider": custom_llm_provider,
-       "custom_api_base": custom_api_base,
+       "api_base": api_base,
        "litellm_call_id": litellm_call_id,
        "model_alias_map": model_alias_map,
        "completion_call_id": completion_call_id,
@@ -834,7 +834,7 @@ def get_optional_params( # use the openai defaults
 def load_test_model(
     model: str,
     custom_llm_provider: str = "",
-    custom_api_base: str = "",
+    api_base: str = "",
     prompt: str = "",
     num_calls: int = 0,
     force_timeout: int = 0,
@@ -852,7 +852,7 @@ def load_test_model(
        model=model,
        messages=messages,
        custom_llm_provider=custom_llm_provider,
-       custom_api_base=custom_api_base,
+       api_base=api_base,
        force_timeout=force_timeout,
    )
    end_time = time.time()
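
Finally, a hedged example of calling the updated load_test_model signature; the parameter names come from the hunks above, while the import path, values, and return handling are assumptions:

```python
# Assumption: the helper is importable from litellm.utils, where the hunk
# above lives; values are illustrative and the return shape is not shown
# in the diff.
from litellm.utils import load_test_model

result = load_test_model(
    model="gpt-3.5-turbo",
    custom_llm_provider="openai",
    api_base="https://api.openai.com/v1",  # was custom_api_base
    prompt="Hey, how's it going?",
    num_calls=5,
)
print(result)
```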