rebase and fix some small breakage due to model -> model_id fix

Ashwin Bharambe 2024-11-12 21:47:39 -08:00
parent 22aedd0277
commit 1cb42d3060
6 changed files with 20 additions and 11 deletions
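The breakage comes from the rename of the `model` keyword argument to `model_id` in the inference API surface: call sites that still passed `model=` had to be updated. A minimal before/after sketch of the call shape, with the `inference_api` object and argument values purely illustrative (the shapes are taken from the diffs below):

    # before the rename (now broken)
    response = await inference_api.chat_completion(model=model, messages=messages, stream=True)
    # after the rename
    response = await inference_api.chat_completion(model_id=model, messages=messages, stream=True)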

@@ -38,15 +38,15 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
     if api == Api.inference:
         return await p.register_model(obj)
     elif api == Api.safety:
-        await p.register_shield(obj)
+        return await p.register_shield(obj)
     elif api == Api.memory:
-        await p.register_memory_bank(obj)
+        return await p.register_memory_bank(obj)
     elif api == Api.datasetio:
-        await p.register_dataset(obj)
+        return await p.register_dataset(obj)
     elif api == Api.scoring:
-        await p.register_scoring_function(obj)
+        return await p.register_scoring_function(obj)
     elif api == Api.eval:
-        await p.register_eval_task(obj)
+        return await p.register_eval_task(obj)
     else:
         raise ValueError(f"Unknown API {api} for registering object with provider")
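This hunk makes each non-inference branch return the provider's result instead of implicitly returning None. A minimal sketch of the effect on a hypothetical caller (variable names here are illustrative, not from this commit):

    # Before the fix, registering e.g. a shield fell through and yielded None,
    # so the caller lost the registered object.
    registered = await register_object_with_provider(shield, provider)
    assert registered is not None  # holds only when the branch returns the result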

@@ -234,7 +234,7 @@ class LlamaGuardShield:
         # TODO: llama-stack inference protocol has issues with non-streaming inference code
         content = ""
         async for chunk in await self.inference_api.chat_completion(
-            model=self.model,
+            model_id=self.model,
             messages=[shield_input_message],
             stream=True,
         ):

@@ -164,7 +164,6 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
-        print(f"model={model}")
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,

@@ -46,6 +46,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="together",
         marks=pytest.mark.together,
     ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "safety": "llama_guard",
+            "memory": "faiss",
+            "agents": "meta_reference",
+        },
+        id="fireworks",
+        marks=pytest.mark.fireworks,
+    ),
     pytest.param(
         {
             "inference": "remote",
@@ -60,7 +70,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
 def pytest_configure(config):
-    for mark in ["meta_reference", "ollama", "together", "remote"]:
+    for mark in ["meta_reference", "ollama", "together", "fireworks", "remote"]:
         config.addinivalue_line(
             "markers",
             f"{mark}: marks tests as {mark} specific",

@@ -147,9 +147,9 @@ class TestInference:
         user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
         response = await inference_impl.completion(
+            model_id=inference_model,
             content=user_input,
             stream=False,
-            model=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),

@@ -55,7 +55,7 @@ class TestVisionModelInference:
         )
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=inference_model,
             messages=[
                 UserMessage(content="You are a helpful assistant."),
                 UserMessage(content=[image, "Describe this image in two sentences."]),
@@ -102,7 +102,7 @@ class TestVisionModelInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=inference_model,
                 messages=[
                     UserMessage(content="You are a helpful assistant."),
                     UserMessage(