fix: use provider-specific routing

Ishaan Jaff 2024-08-07 14:37:20 -07:00
parent 218ba0f470
commit f1ffa82062
4 changed files with 35 additions and 15 deletions

View file

@@ -8,9 +8,15 @@ model_list:
     litellm_params:
       model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
       api_key: "os.environ/FIREWORKS"
-  - model_name: "*"
+  # provider specific wildcard routing
+  - model_name: "anthropic/*"
     litellm_params:
-      model: "*"
+      model: "anthropic/*"
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: "groq/*"
+    litellm_params:
+      model: "groq/*"
+      api_key: os.environ/GROQ_API_KEY
+  - model_name: "*"
+    litellm_params:
+      model: openai/*
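For context, a sketch of how a client would exercise the new provider-specific wildcard routes; the base_url and api_key below are placeholders for a locally running proxy with a master key, not part of this commit:

# Sketch only: calling the proxy's provider-specific wildcard routes.
# base_url and api_key are assumed placeholders, not from this commit.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

# matches the new "anthropic/*" deployment -> uses ANTHROPIC_API_KEY
resp = client.chat.completions.create(
    model="anthropic/claude-3-sonnet-20240229",
    messages=[{"role": "user", "content": "hello"}],
)

# matches "groq/*" -> uses GROQ_API_KEY
resp = client.chat.completions.create(
    model="groq/llama3-8b-8192",
    messages=[{"role": "user", "content": "hello"}],
)

# anything else falls through to the catch-all "*" -> routed as openai/<model>
resp = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)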

View file

@@ -4469,13 +4469,7 @@ class Router:
             )
             model = self.model_group_alias[model]

-        if model not in self.model_names and self.default_deployment is not None:
-            updated_deployment = copy.deepcopy(
-                self.default_deployment
-            )  # self.default_deployment
-            updated_deployment["litellm_params"]["model"] = model
-            return model, updated_deployment
-        elif model not in self.model_names:
+        if model not in self.model_names:
             # check if provider/ specific wildcard routing
             try:
                 (
@@ -4499,6 +4493,14 @@
                 # get_llm_provider raises exception when provider is unknown
                 pass

+        # check if default deployment is set
+        if self.default_deployment is not None:
+            updated_deployment = copy.deepcopy(
+                self.default_deployment
+            )  # self.default_deployment
+            updated_deployment["litellm_params"]["model"] = model
+            return model, updated_deployment
+
         ## get healthy deployments
         ### get all deployments
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
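Taken together, the two router hunks reorder the fallback logic: before this change, the catch-all default deployment was returned before provider-specific wildcards were consulted, so a request like anthropic/claude-3-sonnet-20240229 could be rewritten onto the "*" (openai/*) deployment. A minimal sketch of the new precedence, assuming a plain dict lookup; resolve_deployment and wildcard_deployments are illustrative names, not the Router's real internals beyond what the diff shows:

# Illustrative resolution order after this commit (not the actual Router code):
# 1) exact model_name match, 2) provider wildcard ("anthropic/*", "groq/*", ...),
# 3) catch-all default deployment ("*"), tried only after step 2 fails.
import copy

def resolve_deployment(model, model_names, wildcard_deployments, default_deployment):
    if model in model_names:
        return model, None  # the normal exact-match path handles this case
    # check if provider/ specific wildcard routing applies
    provider = model.split("/", 1)[0]
    deployment = wildcard_deployments.get(f"{provider}/*")
    if deployment is not None:
        return model, deployment
    # only if no provider wildcard matched, fall back to the "*" default
    if default_deployment is not None:
        updated_deployment = copy.deepcopy(default_deployment)
        updated_deployment["litellm_params"]["model"] = model
        return model, updated_deployment
    return model, None

For example, resolve_deployment("anthropic/claude-3-sonnet-20240229", set(), {"anthropic/*": {...}}, default) now returns the anthropic deployment instead of a copy of the default.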

View file

@@ -86,12 +86,16 @@ model_list:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
   # Pass through all llm requests to litellm.completion/litellm.embedding
   # if user passes model="anthropic/claude-3-opus-20240229" proxy will make requests to anthropic claude-3-opus-20240229 using ANTHROPIC_API_KEY
-  - model_name: "*"
+  # provider specific wildcard routing
+  - model_name: "anthropic/*"
     litellm_params:
-      model: "*"
+      model: "anthropic/*"
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: "groq/*"
+    litellm_params:
+      model: "groq/*"
+      api_key: os.environ/GROQ_API_KEY
   - model_name: mistral-embed
     litellm_params:
       model: mistral/mistral-embed
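This config keeps an exact-match entry (mistral-embed) next to the wildcards, and per the router change above, exact model names still resolve before any wildcard is tried. A sketch of both paths through the proxy; base_url and api_key are again placeholders:

# Sketch only: exact match vs. provider wildcard match through the proxy.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

# exact model_name match -> the mistral/mistral-embed deployment
emb = client.embeddings.create(model="mistral-embed", input="hello world")

# provider wildcard match -> the "groq/*" deployment, using GROQ_API_KEY
chat = client.chat.completions.create(
    model="groq/llama3-8b-8192",
    messages=[{"role": "user", "content": "hello"}],
)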

View file

@@ -119,7 +119,9 @@ async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
     print()

     if status != 200:
-        raise Exception(f"Request did not return a 200 status code: {status}")
+        raise Exception(
+            f"Request did not return a 200 status code: {status}, response text={response_text}"
+        )

     response_header_check(
         response
@@ -485,6 +487,12 @@ async def test_proxy_all_models():
         session=session, key=LITELLM_MASTER_KEY, model="groq/llama3-8b-8192"
     )

+    await chat_completion(
+        session=session,
+        key=LITELLM_MASTER_KEY,
+        model="anthropic/claude-3-sonnet-20240229",
+    )
+

 @pytest.mark.asyncio
 async def test_batch_chat_completions():
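The new assertion relies on the file's chat_completion helper shown in the earlier hunk. A hedged standalone equivalent of the added anthropic call, with the proxy URL and master key as assumed placeholders:

# Sketch only: the anthropic wildcard request the test now makes,
# written as a standalone script. URL and key are placeholders.
import asyncio
import aiohttp

async def main():
    async with aiohttp.ClientSession() as session:
        url = "http://localhost:4000/chat/completions"  # assumed proxy address
        headers = {"Authorization": "Bearer sk-1234"}  # assumed master key
        payload = {
            "model": "anthropic/claude-3-sonnet-20240229",
            "messages": [{"role": "user", "content": "ping"}],
        }
        async with session.post(url, headers=headers, json=payload) as response:
            response_text = await response.text()
            if response.status != 200:
                raise Exception(
                    f"Request did not return a 200 status code: {response.status}, response text={response_text}"
                )
            print(response_text)

asyncio.run(main())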