mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
add baseten support
This commit is contained in:
parent d7d9893cf3, commit 02a666ee02
3 changed files with 75 additions and 3 deletions
@@ -431,6 +431,32 @@ def completion(
            generator = get_ollama_response_stream(endpoint, model, prompt)
            # assume all responses are streamed
            return generator
        elif custom_llm_provider == "baseten" or litellm.api_base=="https://app.baseten.co":
            install_and_import("baseten")
            import baseten
            base_ten_key = get_secret('BASETEN_API_KEY')
            baseten.login(base_ten_key)

            prompt = " ".join([message["content"] for message in messages])
            ## LOGGING
            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)

            base_ten__model = baseten.deployed_model_version_id(model)

            completion_response = base_ten__model.predict({"prompt": prompt})
            if type(completion_response) == dict:
                completion_response = completion_response["data"]
                if type(completion_response) == dict:
                    completion_response = completion_response["generated_text"]

            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)

            ## RESPONSE OBJECT
            model_response["choices"][0]["message"]["content"] = completion_response
            model_response["created"] = time.time()
            model_response["model"] = model
            response = model_response

        elif custom_llm_provider == "petals":
            install_and_import("transformers")
            from transformers import AutoTokenizer
@@ -446,7 +472,7 @@ def completion(

            outputs = model.generate(
                inputs=inputs,
                temperature=1.0,
                max_new_tokens=5
            )

            print("got output", outputs)
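For context, a minimal usage sketch of the branch this commit adds, assuming the completion() signature shown above and that litellm's get_secret() reads BASETEN_API_KEY from the environment; the model version id and API key below are placeholders, not part of this commit:

# Hypothetical usage sketch (not part of this commit): exercising the new
# baseten branch of completion().
import os
from litellm import completion

# Placeholder credentials and model id; replace with a real Baseten API key
# and a deployed model version id (assumptions for illustration only).
os.environ["BASETEN_API_KEY"] = "bs-..."
messages = [{"role": "user", "content": "Hello, how are you?"}]

# custom_llm_provider="baseten" routes into the branch added in this diff:
# it logs in with BASETEN_API_KEY, joins the message contents into a prompt,
# calls baseten.deployed_model_version_id(model).predict({"prompt": ...}),
# and packs the text into model_response["choices"][0]["message"]["content"].
response = completion(
    model="MODEL_VERSION_ID",
    messages=messages,
    custom_llm_provider="baseten",
)
print(response["choices"][0]["message"]["content"])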