petals fixes

ishaan-jaff 2023-09-19 09:05:12 -07:00
parent f6ccadabc8
commit 385640b743
4 changed files with 9 additions and 6 deletions

View file

@@ -276,7 +276,7 @@ provider_list: List = [
     "vllm",
     "nlp_cloud",
     "bedrock",
-    "petals,"
+    "petals",
     "custom", # custom apis
 ]

View file

@@ -32,8 +32,6 @@ def completion(
     model = model
-    # You could also use "meta-llama/Llama-2-70b-chat-hf" or any other supported model from 🤗 Model Hub
     tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, add_bos_token=False)
     model = AutoDistributedModelForCausalLM.from_pretrained(model)
     model = model.cuda()
@@ -76,6 +74,8 @@ def completion(
     print_verbose(f"raw model_response: {outputs}")
     ## RESPONSE OBJECT
     output_text = tokenizer.decode(outputs[0])
+    print("output text")
+    print(output_text)
     model_response["choices"][0]["message"]["content"] = output_text
     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
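For context, the Petals flow this handler wraps looks roughly like the standalone sketch below (a minimal sketch; the model name is an assumed example, not taken from this diff):

# Minimal standalone sketch of the Petals generation flow mirrored by the handler above.
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "petals-team/StableBeluga2"  # assumed example model
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, add_bos_token=False)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
model = model.cuda()

inputs = tokenizer("A chat between a user and an assistant.", return_tensors="pt")["input_ids"].cuda()
outputs = model.generate(inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))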

View file

@@ -959,8 +959,9 @@ def completion(
             or custom_llm_provider == "petals-team"
             or model in litellm.petals_models
         ):
-            custom_llm_provider = "baseten"
+            custom_llm_provider = "petals"
+            print("model on petals")
+            print(model)
             model_response = petals.completion(
                 model=model,
                 messages=messages,
@@ -970,7 +971,6 @@ def completion(
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 encoding=encoding,
-                api_key=baseten_key,
                 logging_obj=logging
             )
             if inspect.isgenerator(model_response) or (stream == True):

View file

@@ -1122,6 +1122,9 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
         ## nlp_cloud
         elif model in litellm.nlp_cloud_models:
             custom_llm_provider = "nlp_cloud"
+        ## petals
+        elif model in litellm.petals_models:
+            custom_llm_provider = "petals"
         if custom_llm_provider is None or custom_llm_provider=="":
             raise ValueError(f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/{model}',..)` Learn more: https://docs.litellm.ai/docs/providers")