(feat) cloudflare ai workers - add completion support

ishaan-jaff 2023-12-29 11:31:27 +05:30
parent 6f2734100f
commit 8fcfb7df22
2 changed files with 26 additions and 35 deletions

litellm/llms/cloudflare.py

@@ -84,9 +84,7 @@ def completion(
     ## Load Config
     config = litellm.CloudflareConfig.get_config()
     for k, v in config.items():
-        if (
-            k not in optional_params
-        ):  # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
+        if k not in optional_params:
             optional_params[k] = v
     print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
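
The config-merge loop keeps per-call arguments ahead of provider-level defaults: a key from CloudflareConfig is applied only when the caller did not pass it. A minimal standalone sketch of that precedence (the literal values below are invented for illustration; only the merge logic comes from the hunk above):

# Sketch of the merge precedence; the values are made up for illustration.
optional_params = {"max_tokens": 100}          # supplied by the caller
config = {"max_tokens": 256, "stream": False}  # e.g. CloudflareConfig.get_config()

for k, v in config.items():
    if k not in optional_params:  # caller-supplied values always win
        optional_params[k] = v

print(optional_params)  # {'max_tokens': 100, 'stream': False}
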
@@ -101,45 +99,40 @@ def completion(
             eos_token=model_prompt_details.get("eos_token", ""),
             messages=messages,
         )
-    else:
-        prompt = prompt_factory(
-            model=model,
-            messages=messages,
-            api_key=api_key,
-            custom_llm_provider="together_ai",
-        )  # api key required to query together ai model list
+    # cloudflare adds the model to the api base
+    api_base = api_base + model
     data = {
-        "model": model,
-        "prompt": prompt,
-        "request_type": "language-model-inference",
+        "messages": messages,
         **optional_params,
     }
     ## LOGGING
     logging_obj.pre_call(
-        input=prompt,
+        input=messages,
         api_key=api_key,
         additional_args={
-            "complete_input_dict": data,
             "headers": headers,
             "api_base": api_base,
+            "complete_input_dict": data,
         },
     )
     ## COMPLETION CALL
-    if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
+    if "stream" in optional_params and optional_params["stream"] == True:
         response = requests.post(
             api_base,
             headers=headers,
             data=json.dumps(data),
-            stream=optional_params["stream_tokens"],
+            stream=optional_params["stream"],
         )
         return response.iter_lines()
     else:
         response = requests.post(api_base, headers=headers, data=json.dumps(data))
         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
+            input=messages,
             api_key=api_key,
             original_response=response.text,
             additional_args={"complete_input_dict": data},
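
Taken together, this hunk turns the handler into a plain HTTP call against Workers AI: the model name is appended to the account-scoped base URL and the chat messages are forwarded unchanged. A sketch of the resulting request, with placeholder credentials and one of Cloudflare's published Workers AI model names:

import json

import requests

account_id = "YOUR_ACCOUNT_ID"            # placeholder
api_token = "YOUR_API_TOKEN"              # placeholder
model = "@cf/meta/llama-2-7b-chat-int8"   # a published Workers AI model

# cloudflare adds the model to the api base, as in the diff above
url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/" + model

headers = {"Authorization": f"Bearer {api_token}"}
data = {"messages": [{"role": "user", "content": "Hello!"}]}

# non-streaming branch; the streaming branch passes stream=True and
# iterates response.iter_lines() instead
response = requests.post(url, headers=headers, data=json.dumps(data))
print(response.json())
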
@@ -152,25 +145,21 @@ def completion(
         )
         completion_response = response.json()
-        if len(completion_response["output"]["choices"][0]["text"]) >= 0:
-            model_response["choices"][0]["message"]["content"] = completion_response[
-                "output"
-            ]["choices"][0]["text"]
+        model_response["choices"][0]["message"]["content"] = completion_response[
+            "result"
+        ]["response"]
         ## CALCULATING USAGE
         print_verbose(
-            f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
+            f"CALCULATING CLOUDFLARE TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
         )
-        prompt_tokens = len(encoding.encode(prompt))
+        prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
         completion_tokens = len(
             encoding.encode(model_response["choices"][0]["message"].get("content", ""))
         )
-        if "finish_reason" in completion_response["output"]["choices"][0]:
-            model_response.choices[0].finish_reason = completion_response["output"][
-                "choices"
-            ][0]["finish_reason"]
         model_response["created"] = int(time.time())
-        model_response["model"] = "together_ai/" + model
+        model_response["model"] = "cloudflare/" + model
         usage = Usage(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
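
The parsing change tracks the different response envelopes: TogetherAI nests the completion text under output.choices[0].text, while Cloudflare's v4 API wraps results in a {"result": ..., "success": ...} envelope with the text under result.response. A small sketch with a made-up payload:

# Made-up payload in the Workers AI envelope; only the key layout matters.
completion_response = {
    "result": {"response": "Hello! How can I help you today?"},
    "success": True,
    "errors": [],
    "messages": [],
}

content = completion_response["result"]["response"]
# the removed TogetherAI-style lookup would have been:
# completion_response["output"]["choices"][0]["text"]
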

litellm/main.py

@@ -1572,12 +1572,14 @@ def completion(
             or litellm.api_key
             or get_secret("CLOUDFLARE_API_KEY")
         )
-        # api_base = (
-        #     api_base
-        #     or litellm.api_base
-        #     or get_secret("CLOUDFLARE_API_BASE")
-        #     or "https://api.anthropic.com/v1/complete"
-        # )
+        account_id = get_secret("CLOUDFLARE_ACCOUNT_ID")
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("CLOUDFLARE_API_BASE")
+            or f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/"
+        )
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
         response = cloudflare.completion(
             model=model,
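
With the route in place, the provider can be exercised end to end. A usage sketch, assuming the cloudflare/ model prefix selects this branch and the environment variables read above are set (credentials are placeholders):

import os

import litellm

os.environ["CLOUDFLARE_API_KEY"] = "..."       # placeholder API token
os.environ["CLOUDFLARE_ACCOUNT_ID"] = "..."    # placeholder account id

response = litellm.completion(
    model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
    messages=[{"role": "user", "content": "Hello from litellm!"}],
)
print(response.choices[0].message.content)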