Mirror of https://github.com/BerriAI/litellm.git
petals fixes

commit 385640b743 (parent f6ccadabc8)
4 changed files with 9 additions and 6 deletions
@@ -276,7 +276,7 @@ provider_list: List = [
     "vllm",
     "nlp_cloud",
     "bedrock",
-    "petals,"
+    "petals",
     "custom", # custom apis
 ]
 
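The typo matters because provider resolution does exact string matching against this list: with the literal `"petals,"` (comma inside the quotes) in provider_list, the string "petals" was never a recognized provider. A minimal sketch of the check this unblocks, assuming provider_list is exported at package level (the later hunks reference litellm.petals_models the same way):

```python
import litellm

# Before the fix the list held the literal string "petals," (comma inside the
# quotes), so an exact membership test for "petals" could never succeed.
print("petals" in litellm.provider_list)   # expected True after this commit
print("petals," in litellm.provider_list)  # expected False after this commit
```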
@@ -32,8 +32,6 @@ def completion(
 
     model = model
 
-    # You could also use "meta-llama/Llama-2-70b-chat-hf" or any other supported model from 🤗 Model Hub
-
     tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, add_bos_token=False)
     model = AutoDistributedModelForCausalLM.from_pretrained(model)
     model = model.cuda()
@@ -76,6 +74,8 @@ def completion(
     print_verbose(f"raw model_response: {outputs}")
     ## RESPONSE OBJECT
     output_text = tokenizer.decode(outputs[0])
+    print("output text")
+    print(output_text)
     model_response["choices"][0]["message"]["content"] = output_text
 
     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
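The two hunks above touch the Petals backend's completion(): the first drops a stale Hugging Face comment, the second adds debug prints around the decoded output. For context, a minimal sketch of the raw Petals flow this backend wraps, with an illustrative model id (not taken from this diff) and assuming a CUDA device, since the backend calls model.cuda():

```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

# Illustrative checkpoint; any model served by a Petals swarm would do.
model_name = "petals-team/StableBeluga2"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, add_bos_token=False)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
model = model.cuda()  # client-side layers on GPU, matching the hunk above

inputs = tokenizer("What is the meaning of life?", return_tensors="pt")["input_ids"].cuda()
outputs = model.generate(inputs, max_new_tokens=64)

# Same decode step the new print statements expose in the backend.
print(tokenizer.decode(outputs[0]))
```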
@@ -959,8 +959,9 @@ def completion(
        or custom_llm_provider == "petals-team"
        or model in litellm.petals_models
    ):
-        custom_llm_provider = "baseten"
+        custom_llm_provider = "petals"
+        print("model on petals")
+        print(model)
        model_response = petals.completion(
            model=model,
            messages=messages,
@@ -970,7 +971,6 @@ def completion(
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            encoding=encoding,
-            api_key=baseten_key,
            logging_obj=logging
        )
        if inspect.isgenerator(model_response) or (stream == True):
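These two hunks fix the routing in litellm's top-level completion(): requests that matched the Petals condition were previously labeled with the "baseten" provider and passed a baseten_key; they now set custom_llm_provider = "petals" and call petals.completion() without that key. A hedged usage sketch, with an illustrative model id that would need to appear in litellm.petals_models (or be routed explicitly with custom_llm_provider="petals"):

```python
import litellm

# Illustrative Petals-served model id; requires a working Petals client and a
# CUDA device, per the backend's model.cuda() call.
response = litellm.completion(
    model="petals-team/StableBeluga2",
    messages=[{"role": "user", "content": "Hello from the Petals swarm"}],
)

# litellm returns an OpenAI-style response, as the backend hunk that sets
# model_response["choices"][0]["message"]["content"] suggests.
print(response["choices"][0]["message"]["content"])
```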
@@ -1122,6 +1122,9 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
        ## nlp_cloud
        elif model in litellm.nlp_cloud_models:
            custom_llm_provider = "nlp_cloud"
+        ## petals
+        elif model in litellm.petals_models:
+            custom_llm_provider = "petals"
 
        if custom_llm_provider is None or custom_llm_provider=="":
            raise ValueError(f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/{model}',..)` Learn more: https://docs.litellm.ai/docs/providers")
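The last hunk teaches get_llm_provider() to map models listed in litellm.petals_models to the "petals" provider instead of falling through to the ValueError below it. A simplified sketch of the resolution order after this change; this illustrates the branch logic only and is not the actual litellm function:

```python
from typing import Optional

import litellm


def resolve_provider(model: str, custom_llm_provider: Optional[str] = None) -> str:
    """Mirror of the branch order shown in the hunk above (illustrative only)."""
    if custom_llm_provider:
        return custom_llm_provider
    if model in litellm.nlp_cloud_models:
        return "nlp_cloud"
    if model in litellm.petals_models:  # the branch this commit adds
        return "petals"
    raise ValueError(
        "LLM Provider NOT provided. Pass in the LLM provider you are trying to call. "
        "Learn more: https://docs.litellm.ai/docs/providers"
    )
```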