
Testing FastChat Message Generation Pipeline in lm-sys/FastChat

This test module sends a test message to a FastChat language model, exercising the message generation pipeline and the model's response handling. It is a small command-line utility for testing model interactions and streaming responses.

Test Coverage Overview

The test suite covers essential messaging functionality for FastChat models, including:

  • Model worker discovery and availability verification
  • Conversation template handling and prompt construction
  • Streaming response generation and processing
  • Parameter configuration and validation
  • Error handling for unavailable workers

Implementation Analysis

The test takes the form of a command-line script. It issues REST API calls to the controller and worker services, supporting both direct worker addressing and dynamic worker discovery through the controller, and processes the worker's streamed output in real time.
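
As a minimal sketch of the discovery flow (the controller address is the script's default; the model name is a placeholder), the controller is first asked to refresh its worker registry and then to resolve the model to a worker:

import requests

controller_addr = "http://localhost:21001"  # the script's default
model_name = "my-model"  # placeholder model name

# Refresh the controller's worker registry and list the served models.
requests.post(controller_addr + "/refresh_all_workers")
models = requests.post(controller_addr + "/list_models").json()["models"]

# Resolve the model to a worker address; an empty string means no worker.
ret = requests.post(
    controller_addr + "/get_worker_address", json={"model": model_name}
)
worker_addr = ret.json()["address"]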

Key technical patterns include parameter configuration through argparse, conversation state management via templates, and chunked response processing with a byte delimiter.
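
The chunked-response pattern can be sketched as follows (the worker address and generation parameters are placeholders standing in for the values the script constructs; each NUL-delimited chunk is a JSON object whose "text" field holds the full output so far):

import json

import requests

worker_addr = "http://localhost:21002"  # placeholder worker address
gen_params = {"model": "my-model", "prompt": "Hi", "temperature": 0.0,
              "max_new_tokens": 32, "echo": False}  # abridged parameters

response = requests.post(
    worker_addr + "/worker_generate_stream",
    headers={"User-Agent": "FastChat Client"},
    json=gen_params,
    stream=True,
)

prev = 0
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        output = data["text"].strip()
        print(output[prev:], end="", flush=True)  # print only the new suffix
        prev = len(output)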

Technical Details

  • Uses requests library for HTTP communication
  • Implements argparse for CLI parameter handling
  • Employs JSON for data serialization
  • Utilizes conversation templates for prompt formatting (see the sketch after this list)
  • Implements streaming response handling with a NUL-byte delimiter
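
For example, prompt construction with a conversation template looks roughly like this (the model name is illustrative; the template determines the role names and separators):

from fastchat.model.model_adapter import get_conversation_template

conv = get_conversation_template("vicuna-7b-v1.5")  # illustrative model name
conv.append_message(conv.roles[0], "Tell me a story.")  # user turn
conv.append_message(conv.roles[1], None)  # empty assistant turn to be generated
prompt = conv.get_prompt()  # rendered prompt string sent to the worker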

Best Practices Demonstrated

The test implementation showcases several testing best practices:

  • Clear separation of concerns between worker discovery and message generation
  • Robust parameter handling with sensible defaults
  • Graceful error handling with clear user feedback (e.g., when no worker is available)
  • Efficient streaming response processing
  • Modular design for easy maintenance and extension

fastchat/serve/test_message.py

"""Send a test message."""
import argparse
import json

import requests

from fastchat.model.model_adapter import get_conversation_template


def main():
    # `args` is populated at module scope in the __main__ block below.
    model_name = args.model_name

    if args.worker_address:
        worker_addr = args.worker_address
    else:
        # Discover a worker through the controller: refresh its registry,
        # list the served models, then resolve the model to a worker address.
        controller_addr = args.controller_address
        requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")

        ret = requests.post(
            controller_addr + "/get_worker_address", json={"model": model_name}
        )
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        # The controller returns an empty address when no worker serves the model.
        print(f"No available workers for {model_name}")
        return

    # Build the prompt from the model's conversation template: append the user
    # message and an empty assistant turn, then render the full prompt string.
    conv = get_conversation_template(model_name)
    conv.append_message(conv.roles[0], args.message)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    headers = {"User-Agent": "FastChat Client"}
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": args.temperature,
        "max_new_tokens": args.max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }
    response = requests.post(
        worker_addr + "/worker_generate_stream",
        headers=headers,
        json=gen_params,
        stream=True,
    )

    print(f"{conv.roles[0]}: {args.message}")
    print(f"{conv.roles[1]}: ", end="")
    prev = 0
    # Each NUL-delimited chunk carries the full text generated so far;
    # print only the suffix that is new relative to the previous chunk.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print("")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--max-new-tokens", type=int, default=32)
    parser.add_argument(
        "--message", type=str, default="Tell me a story with more than 1000 words."
    )
    args = parser.parse_args()

    main()
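
A typical invocation runs the module against a running controller (the model name here is illustrative):

python3 -m fastchat.serve.test_message --model-name vicuna-7b-v1.5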