(fix) use OpenAI organization in ahealth_check

This commit is contained in:
ishaan-jaff 2024-01-30 11:45:22 -08:00
parent 42005ad964
commit 2806a2e99f
3 changed files with 7 additions and 2 deletions

View file

@ -759,8 +759,11 @@ class OpenAIChatCompletion(BaseLLM):
messages: Optional[list] = None, messages: Optional[list] = None,
input: Optional[list] = None, input: Optional[list] = None,
prompt: Optional[str] = None, prompt: Optional[str] = None,
organization: Optional[str] = None,
): ):
client = AsyncOpenAI(api_key=api_key, timeout=timeout) client = AsyncOpenAI(
api_key=api_key, timeout=timeout, organization=organization
)
if model is None and mode != "image_generation": if model is None and mode != "image_generation":
raise Exception("model is not set") raise Exception("model is not set")

View file

@ -3227,6 +3227,7 @@ async def ahealth_check(
or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "text-completion-openai"
): ):
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY") api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
organization = model_params.get("organization")
timeout = ( timeout = (
model_params.get("timeout") model_params.get("timeout")
@ -3244,6 +3245,7 @@ async def ahealth_check(
mode=mode, mode=mode,
prompt=prompt, prompt=prompt,
input=input, input=input,
organization=organization,
) )
else: else:
if mode == "embedding": if mode == "embedding":

View file

@ -1411,7 +1411,7 @@ class Router:
max_retries = litellm.get_secret(max_retries_env_name) max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries litellm_params["max_retries"] = max_retries
organization = litellm_params.pop("organization", None) organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"): if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "") organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name) organization = litellm.get_secret(organization_env_name)