Testing CLI Model Inference Workflows in FastChat
This test suite validates the command line interface functionality for model inference in FastChat, focusing on different GPU configurations and model loading scenarios. It ensures reliable model loading and inference across various popular language models.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
lm-sys/fastchat
tests/test_cli.py
"""Test command line interface for model inference."""
import argparse
import os
from fastchat.utils import run_cmd
def test_single_gpu():
    """Run the FastChat CLI once for every model in a fixed list.

    Each model is passed to ``python3 -m fastchat.serve.cli`` with the
    ``programmatic`` style, with stdin redirected from
    ``test_cli_inputs.txt`` (resolved relative to the current working
    directory).  Entries that live under ``model_weights`` are skipped when
    the expanded path does not exist on this machine.

    Raises:
        RuntimeError: If ``run_cmd`` returns a non-zero value for any model.
            Iteration stops at the first failing model.
    """
    models = [
        "lmsys/vicuna-7b-v1.5",
        "lmsys/longchat-7b-16k",
        "lmsys/fastchat-t5-3b-v1.0",
        "meta-llama/Llama-2-7b-chat-hf",
        "THUDM/chatglm-6b",
        "THUDM/chatglm2-6b",
        "mosaicml/mpt-7b-chat",
        "tiiuae/falcon-7b-instruct",
        "~/model_weights/alpaca-7b",
        "~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth",
    ]

    for model_path in models:
        # Local checkpoints are optional: skip them when they are not present.
        if "model_weights" in model_path and not os.path.exists(
            os.path.expanduser(model_path)
        ):
            continue

        cmd = (
            f"python3 -m fastchat.serve.cli --model-path {model_path} "
            f"--style programmatic < test_cli_inputs.txt"
        )
        # NOTE(review): run_cmd is presumed to return the command's exit
        # status -- confirm against fastchat.utils.
        ret = run_cmd(cmd)
        if ret != 0:
            # Fail loudly: a bare ``return`` here would let a broken run be
            # reported as a passing test.
            raise RuntimeError(f"Command failed with return value {ret}: {cmd}")

        print("")
def test_multi_gpu():
    """Run the FastChat CLI with ``--num-gpus 2`` for each listed model.

    Each model is passed to ``python3 -m fastchat.serve.cli`` with the
    ``programmatic`` style, ``--num-gpus 2`` and ``--max-gpu-memory 14Gib``,
    with stdin redirected from ``test_cli_inputs.txt`` (resolved relative to
    the current working directory).

    Raises:
        RuntimeError: If ``run_cmd`` returns a non-zero value for any model.
            Iteration stops at the first failing model.
    """
    models = [
        "lmsys/vicuna-13b-v1.3",
    ]

    for model_path in models:
        cmd = (
            f"python3 -m fastchat.serve.cli --model-path {model_path} "
            f"--style programmatic --num-gpus 2 --max-gpu-memory 14Gib < test_cli_inputs.txt"
        )
        # NOTE(review): run_cmd is presumed to return the command's exit
        # status -- confirm against fastchat.utils.
        ret = run_cmd(cmd)
        if ret != 0:
            # Fail loudly: a bare ``return`` here would let a broken run be
            # reported as a passing test.
            raise RuntimeError(f"Command failed with return value {ret}: {cmd}")

        print("")
def test_8bit():
    """Run the FastChat CLI with ``--load-8bit`` for each listed model.

    Each model is passed to ``python3 -m fastchat.serve.cli`` with the
    ``programmatic`` style and the ``--load-8bit`` flag, with stdin
    redirected from ``test_cli_inputs.txt`` (resolved relative to the
    current working directory).

    Raises:
        RuntimeError: If ``run_cmd`` returns a non-zero value for any model.
            Iteration stops at the first failing model.
    """
    models = [
        "lmsys/vicuna-13b-v1.3",
    ]

    for model_path in models:
        cmd = (
            f"python3 -m fastchat.serve.cli --model-path {model_path} "
            f"--style programmatic --load-8bit < test_cli_inputs.txt"
        )
        # NOTE(review): run_cmd is presumed to return the command's exit
        # status -- confirm against fastchat.utils.
        ret = run_cmd(cmd)
        if ret != 0:
            # Fail loudly: a bare ``return`` here would let a broken run be
            # reported as a passing test.
            raise RuntimeError(f"Command failed with return value {ret}: {cmd}")

        print("")
def test_hf_api():
    """Run ``fastchat.serve.huggingface_api`` once for each listed model.

    Each model is passed via ``--model-path`` to
    ``python3 -m fastchat.serve.huggingface_api``; no input file is
    redirected for this command.

    Raises:
        RuntimeError: If ``run_cmd`` returns a non-zero value for any model.
            Iteration stops at the first failing model.
    """
    models = [
        "lmsys/vicuna-7b-v1.5",
        "lmsys/fastchat-t5-3b-v1.0",
    ]

    for model_path in models:
        cmd = f"python3 -m fastchat.serve.huggingface_api --model-path {model_path}"
        # NOTE(review): run_cmd is presumed to return the command's exit
        # status -- confirm against fastchat.utils.
        ret = run_cmd(cmd)
        if ret != 0:
            # Fail loudly: a bare ``return`` here would let a broken run be
            # reported as a passing test.
            raise RuntimeError(f"Command failed with return value {ret}: {cmd}")

        print("")
if __name__ == "__main__":
    # When executed as a script, run every CLI check in a fixed order:
    # single GPU, multi GPU, 8-bit loading, then the Hugging Face API.
    for _check in (test_single_gpu, test_multi_gpu, test_8bit, test_hf_api):
        _check()