custom api base

This commit is contained in:
ishaan-jaff 2023-09-05 16:50:32 -07:00
parent 262bb07ade
commit 8ecef03f63
3 changed files with 78 additions and 1 deletion

View file

@ -191,6 +191,7 @@ provider_list = [
"azure",
"sagemaker",
"bedrock",
"custom", # custom apis
]
models_by_provider = {

View file

@ -546,7 +546,11 @@ def completion(
}
response = model_response
elif (
model in litellm.huggingface_models or custom_llm_provider == "huggingface"
(
model in litellm.huggingface_models and
custom_llm_provider!="custom" # if users use a hf model, with a custom/provider. See implementation of custom_llm_provider == custom
) or
custom_llm_provider == "huggingface"
):
custom_llm_provider = "huggingface"
huggingface_key = (
@ -783,6 +787,62 @@ def completion(
)
return response
response = model_response
elif (
custom_llm_provider == "custom"
):
import requests
url = (
litellm.api_base or
api_base
)
"""
assume input to custom LLM api bases follow this format:
resp = requests.post(
api_base,
json={
'model': 'meta-llama/Llama-2-13b-hf', # model name
'params': {
'prompt': ["The capital of France is P"],
'max_tokens': 32,
'temperature': 0.7,
'top_p': 1.0,
'top_k': 40,
}
}
)
"""
prompt = " ".join([message["content"] for message in messages])
resp = requests.post(url, json={
'model': model,
'params': {
'prompt': [prompt],
'max_tokens': max_tokens,
'temperature': temperature,
'top_p': top_p,
'top_k': top_k,
}
})
resp = resp.json()
"""
assume all responses from custom api_bases of this format:
{
'data': [
{
'prompt': 'The capital of France is P',
'output': ['The capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France'],
'params': {'temperature': 0.7, 'top_k': 40, 'top_p': 1}}],
'message': 'ok'
}
]
}
"""
string_response = resp['data'][0]['output'][0]
## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = string_response
model_response["created"] = time.time()
model_response["model"] = model
response = model_response
else:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"

View file

@ -435,6 +435,22 @@ def test_completion_sagemaker():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_custom_api_base():
    """Smoke-test completion() routed through a user-supplied API base.

    The "custom/" model prefix selects the custom-provider code path, which
    POSTs the prompt to the given ``api_base`` URL. Any exception is turned
    into a pytest failure rather than propagating.
    """
    try:
        result = completion(
            model="custom/meta-llama/Llama-2-13b-hf",
            messages=messages,
            temperature=0.2,
            max_tokens=10,
            api_base="https://api.autoai.dev/inference",
        )
        # Add any assertions here to check the response
        print("got response\n", result)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_completion_custom_api_base()
# def test_vertex_ai():
# model_name = "chat-bison"
# try: