add support for keys

This commit is contained in:
ishaan-jaff 2023-08-17 15:30:40 -07:00
parent 3f5e47e3ce
commit e72836d181
10 changed files with 13 additions and 141 deletions

View file

@ -187,7 +187,7 @@ def completion(
}
response = model_response
elif model in litellm.anthropic_models:
anthropic_key = api_key if api_key is not None else litellm.anthropic_key
anthropic_key = api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY")
anthropic_client = AnthropicLLM(encoding=encoding, default_max_tokens_to_sample=litellm.max_tokens, api_key=anthropic_key)
model_response = anthropic_client.completion(model=model, messages=messages, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn)
if 'stream' in optional_params and optional_params['stream'] == True:
@ -241,12 +241,7 @@ def completion(
# import cohere/if it fails then pip install cohere
install_and_import("cohere")
import cohere
if api_key:
cohere_key = api_key
elif litellm.cohere_key:
cohere_key = litellm.cohere_key
else:
cohere_key = get_secret("COHERE_API_KEY")
cohere_key = api_key or litellm.cohere_key or get_secret("COHERE_API_KEY") or get_secret("CO_API_KEY")
co = cohere.Client(cohere_key)
prompt = " ".join([message["content"] for message in messages])
## LOGGING
@ -279,7 +274,7 @@ def completion(
response = model_response
elif model in litellm.huggingface_models or custom_llm_provider == "huggingface":
custom_llm_provider = "huggingface"
huggingface_key = api_key if api_key is not None else litellm.huggingface_key
huggingface_key = api_key or litellm.huggingface_key or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_API_KEY")
huggingface_client = HuggingfaceRestAPILLM(encoding=encoding, api_key=huggingface_key)
model_response = huggingface_client.completion(model=model, messages=messages, custom_api_base=custom_api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn)
if 'stream' in optional_params and optional_params['stream'] == True: