allow non-TGI LLMs

ishaan-jaff 2023-09-18 10:26:55 -07:00
parent e83d89d12f
commit e7f4e8b4a4


@@ -102,6 +102,30 @@ def completion(
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
}
input_text = prompt
elif task == "other":
print("task=other, custom api base")
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
)
else:
prompt = prompt_factory(model=model, messages=messages)
inference_params = copy.deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
print("inf params")
print(inference_params)
data = {
"inputs": prompt,
"parameters": inference_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
}
input_text = prompt
## LOGGING
logging_obj.pre_call(
input=input_text,
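
The new branch reuses the TGI-style request shape for arbitrary (non-TGI) text-generation endpoints: it renders the prompt (honoring any custom prompt template registered for the model), strips `details` and `return_full_text` (parameters only TGI understands), and sends the same `inputs`/`parameters`/`stream` payload to the caller-supplied API base. Below is a minimal sketch of both pieces, assuming litellm's `register_prompt_template` helper (which populates `custom_prompt_dict`) and a hypothetical endpoint URL and token; it is an illustration of the idea, not the library's exact internals.

    import json
    import requests
    import litellm

    # Optional: register a custom prompt template so the branch above uses
    # custom_prompt() instead of prompt_factory(). The model id is a placeholder.
    litellm.register_prompt_template(
        model="my-org/my-custom-model",
        roles={
            "system": {"pre_message": "<<SYS>>\n", "post_message": "\n<</SYS>>\n"},
            "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
            "assistant": {"pre_message": "", "post_message": "\n"},
        },
        initial_prompt_value="",
        final_prompt_value="",
    )

    # What the "other" branch effectively sends to a custom inference endpoint.
    api_base = "https://my-endpoint.example.com/generate"  # hypothetical URL
    headers = {"Authorization": "Bearer hf_xxx", "Content-Type": "application/json"}
    data = {
        "inputs": "[INST] Hello, how are you? [/INST]\n",  # rendered prompt
        "parameters": {"max_new_tokens": 50},               # optional_params minus TGI-only keys
        "stream": False,
    }
    response = requests.post(api_base, headers=headers, data=json.dumps(data))
    print(response.json())  # typically [{"generated_text": "..."}]
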
@@ -161,6 +185,8 @@ def completion(
                for token in completion_response[0]["details"]["tokens"]:
                    sum_logprob += token["logprob"]
                model_response["choices"][0]["message"]["logprobs"] = sum_logprob
        elif task == "other":
            model_response["choices"][0]["message"]["content"] = str(completion_response[0]["generated_text"])
        ## CALCULATING USAGE
        prompt_tokens = len(
            encoding.encode(input_text)
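
On the response side, the new `elif task == "other"` branch lifts `generated_text` from the first element of the endpoint's JSON reply into the OpenAI-style `message.content`, and the usage block then counts prompt and completion tokens via `encoding.encode`. A hedged end-to-end usage sketch follows; the model id and endpoint are placeholders, and exact keyword handling depends on the litellm version.

    import litellm

    response = litellm.completion(
        model="huggingface/my-org/my-custom-model",   # placeholder; falls into the "other" task
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        api_base="https://my-endpoint.example.com",   # custom, non-TGI endpoint
        max_tokens=50,
    )

    print(response["choices"][0]["message"]["content"])  # generated_text from the endpoint
    print(response["usage"])  # token counts derived from encoding.encode
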