(feat) cloudflare ai workers - add completion support

ishaan-jaff 2023-12-29 11:31:27 +05:30
parent 6f2734100f
commit 8fcfb7df22
2 changed files with 26 additions and 35 deletions
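For context, a minimal usage sketch of what this commit enables, assuming litellm's usual provider-prefix routing. The model id is illustrative; the env var names match the secrets read in the diffs below:

import os
import litellm

# Env vars read by this commit: CLOUDFLARE_API_KEY and CLOUDFLARE_ACCOUNT_ID
os.environ["CLOUDFLARE_API_KEY"] = "..."    # placeholder
os.environ["CLOUDFLARE_ACCOUNT_ID"] = "..." # placeholder

# "@cf/meta/llama-2-7b-chat-int8" is an illustrative Workers AI model id
response = litellm.completion(
    model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
    messages=[{"role": "user", "content": "Hello from Workers AI"}],
)
print(response.choices[0].message.content)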


@@ -84,9 +84,7 @@ def completion(
     ## Load Config
     config = litellm.CloudflareConfig.get_config()
     for k, v in config.items():
-        if (
-            k not in optional_params
-        ):  # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
+        if k not in optional_params:
             optional_params[k] = v
     print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
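The config merge above keeps the precedence spelled out in the removed comment: a parameter passed at call time beats the same parameter set on the provider config. A standalone sketch of that rule (the parameter names are illustrative):

config = {"max_tokens": 256, "temperature": 0.2}  # provider-level defaults
optional_params = {"max_tokens": 10}              # passed per completion() call

for k, v in config.items():
    if k not in optional_params:  # the call-time value wins
        optional_params[k] = v

assert optional_params == {"max_tokens": 10, "temperature": 0.2}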
@@ -101,45 +99,40 @@ def completion(
             eos_token=model_prompt_details.get("eos_token", ""),
             messages=messages,
         )
-    else:
-        prompt = prompt_factory(
-            model=model,
-            messages=messages,
-            api_key=api_key,
-            custom_llm_provider="together_ai",
-        )  # api key required to query together ai model list
+
+    # cloudflare adds the model to the api base
+    api_base = api_base + model
+
     data = {
-        "model": model,
-        "prompt": prompt,
-        "request_type": "language-model-inference",
+        "messages": messages,
         **optional_params,
     }
     ## LOGGING
     logging_obj.pre_call(
-        input=prompt,
+        input=messages,
         api_key=api_key,
         additional_args={
-            "complete_input_dict": data,
             "headers": headers,
             "api_base": api_base,
+            "complete_input_dict": data,
         },
     )
     ## COMPLETION CALL
-    if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
+    if "stream" in optional_params and optional_params["stream"] == True:
         response = requests.post(
             api_base,
             headers=headers,
             data=json.dumps(data),
-            stream=optional_params["stream_tokens"],
+            stream=optional_params["stream"],
         )
         return response.iter_lines()
     else:
         response = requests.post(api_base, headers=headers, data=json.dumps(data))
         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
+            input=messages,
             api_key=api_key,
             original_response=response.text,
             additional_args={"complete_input_dict": data},
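The hunk above appends the model to the api_base and posts an OpenAI-style messages payload. A sketch of the equivalent raw call, assuming the account-scoped default base URL added in the second file and Cloudflare's standard bearer-token header (the header construction is not part of this diff):

import json
import os
import requests

account_id = os.environ["CLOUDFLARE_ACCOUNT_ID"]
api_base = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/"
model = "@cf/meta/llama-2-7b-chat-int8"  # illustrative model id

url = api_base + model  # cloudflare adds the model to the api base
headers = {"Authorization": f"Bearer {os.environ['CLOUDFLARE_API_KEY']}"}
data = {"messages": [{"role": "user", "content": "Hi"}]}

resp = requests.post(url, headers=headers, data=json.dumps(data))
print(resp.json())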
@@ -152,25 +145,21 @@ def completion(
         )
         completion_response = response.json()
-        if len(completion_response["output"]["choices"][0]["text"]) >= 0:
-            model_response["choices"][0]["message"]["content"] = completion_response[
-                "output"
-            ]["choices"][0]["text"]
+        model_response["choices"][0]["message"]["content"] = completion_response[
+            "result"
+        ]["response"]
         ## CALCULATING USAGE
         print_verbose(
-            f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
+            f"CALCULATING CLOUDFLARE TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
         )
-        prompt_tokens = len(encoding.encode(prompt))
+        prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
         completion_tokens = len(
             encoding.encode(model_response["choices"][0]["message"].get("content", ""))
         )
-        if "finish_reason" in completion_response["output"]["choices"][0]:
-            model_response.choices[0].finish_reason = completion_response["output"][
-                "choices"
-            ][0]["finish_reason"]
         model_response["created"] = int(time.time())
-        model_response["model"] = "together_ai/" + model
+        model_response["model"] = "cloudflare/" + model
         usage = Usage(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
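The parsing above reads the generated text from result.response, which matches Cloudflare's response envelope. An illustrative (assumed) non-streaming payload and the same extraction:

# Assumed shape of a non-streaming Workers AI response
completion_response = {
    "result": {"response": "Hello! How can I help?"},
    "success": True,
    "errors": [],
    "messages": [],
}

# Same extraction as the diff above
content = completion_response["result"]["response"]
assert content == "Hello! How can I help?"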


@@ -1572,12 +1572,14 @@ def completion(
             or litellm.api_key
             or get_secret("CLOUDFLARE_API_KEY")
         )
-        # api_base = (
-        #     api_base
-        #     or litellm.api_base
-        #     or get_secret("CLOUDFLARE_API_BASE")
-        #     or "https://api.anthropic.com/v1/complete"
-        # )
+        account_id = get_secret("CLOUDFLARE_ACCOUNT_ID")
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("CLOUDFLARE_API_BASE")
+            or f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/"
+        )
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
         response = cloudflare.completion(
             model=model,
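The fallback chain above means an explicit api_base (argument, litellm.api_base, or CLOUDFLARE_API_BASE) always wins over the account-scoped default. A sketch of that precedence in isolation, with os.environ.get standing in for litellm's get_secret:

import os

def resolve_api_base(api_base=None):
    # Mirrors the fallback order in the diff above; os.environ.get is a
    # stand-in for litellm's get_secret.
    account_id = os.environ.get("CLOUDFLARE_ACCOUNT_ID")
    return (
        api_base
        or os.environ.get("CLOUDFLARE_API_BASE")
        or f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/"
    )

print(resolve_api_base())  # account-scoped default
print(resolve_api_base("https://my-proxy.example/ai/run/"))  # explicit override wins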