Testing Image Processing Utilities in FastChat
This test suite validates image utility functions in FastChat, focusing on image resizing, moderation, and conversation template handling. It checks that images are resized only when they exceed a configured size limit, and that base64 conversion respects the image size limits of different chat model templates (e.g. OpenAI and Claude).
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
lm-sys/fastchat
tests/test_image_utils.py
"""
Usage:
python3 -m unittest tests.test_image_utils
"""
import base64
from io import BytesIO
import os
import unittest
import numpy as np
from PIL import Image
from fastchat.utils import (
resize_image_and_return_image_in_bytes,
image_moderation_filter,
)
from fastchat.conversation import get_conv_template
def check_byte_size_in_mb(image_base64_str):
    """Return the length of ``image_base64_str`` expressed in MiB (1024 * 1024 units).

    Accepts anything that supports ``len()``. The tests pass both base64 ``str``
    values (measured in characters) and raw ``bytes`` (measured in bytes).
    """
    units_per_mb = 1024 * 1024
    return len(image_base64_str) / units_per_mb
def generate_random_image(target_size_mb, image_format="PNG"):
    """Create a square random-noise RGB image at least ``target_size_mb`` MiB when encoded.

    Uniformly random pixels compress poorly, so the encoded size grows roughly
    with the pixel count. The image is re-generated with a larger side length
    until its encoding in ``image_format`` reaches the target.

    Args:
        target_size_mb: Minimum encoded size, in MiB (1024 * 1024 bytes).
        image_format: Pillow format name used for the size measurement
            (e.g. "PNG").

    Returns:
        The generated ``PIL.Image.Image``. Its encoding in ``image_format``
        is at least ``target_size_mb`` MiB.
    """
    target_size_bytes = target_size_mb * 1024 * 1024
    # Raw RGB pixels take 3 bytes each; start from the square that holds
    # target_size_bytes of raw data. Clamp to 1 so a tiny or zero target
    # still produces an encodable (non-empty) image.
    dimension = max(1, int((target_size_bytes / 3) ** 0.5))

    while True:
        pixel_data = np.random.randint(
            0, 256, (dimension, dimension, 3), dtype=np.uint8
        )
        img = Image.fromarray(pixel_data)

        # Measure the encoded size in memory rather than via a temp file in
        # the working directory, which was never cleaned up and could
        # collide between concurrently running tests.
        buffer = BytesIO()
        img.save(buffer, format=image_format)
        encoded_size = len(buffer.getvalue())
        if encoded_size >= target_size_bytes:
            return img

        # Encoded size scales roughly with dimension**2, so grow the side by
        # the square root of the remaining ratio. Always grow by at least one
        # pixel so the loop is guaranteed to make progress.
        growth = (target_size_bytes / encoded_size) ** 0.5
        dimension = max(dimension + 1, int(dimension * growth))
class DontResizeIfLessThanMaxTest(unittest.TestCase):
    """An image already below ``max_image_size_mb`` should come back with its size unchanged."""

    def test_dont_resize_if_less_than_max(self):
        max_image_size = 5
        initial_size_mb = 0.1  # Well below max_image_size, so no resize is expected
        img = generate_random_image(initial_size_mb)

        # Baseline: size of the image encoded as PNG, with no resizing.
        original_buffer = BytesIO()
        img.save(original_buffer, format="PNG")
        previous_image_size = check_byte_size_in_mb(original_buffer.getvalue())

        result_buffer = resize_image_and_return_image_in_bytes(
            img, max_image_size_mb=max_image_size
        )
        new_image_size = check_byte_size_in_mb(result_buffer.getvalue())
        # NOTE(review): exact equality assumes the helper re-encodes an
        # unresized image as PNG, matching the baseline above — confirm.
        self.assertEqual(previous_image_size, new_image_size)
class ResizeLargeImageForModerationEndpoint(unittest.TestCase):
    """An oversized random-noise image should pass through the moderation filter unflagged."""

    def test_resize_large_image_and_send_to_moderation_filter(self):
        # Larger than what the moderation endpoint can take, so
        # image_moderation_filter has to cope with an oversized input.
        initial_size_mb = 6
        img = generate_random_image(initial_size_mb)
        nsfw_flag, csam_flag = image_moderation_filter(img)
        # Random noise should trip neither classifier; check both flags.
        self.assertFalse(nsfw_flag)
        self.assertFalse(csam_flag)
class DontResizeIfMaxImageSizeIsNone(unittest.TestCase):
    """With ``max_image_size_mb=None`` the image size should be left unchanged."""

    def test_dont_resize_if_max_image_size_is_none(self):
        initial_size_mb = 0.2  # Initial image size
        img = generate_random_image(initial_size_mb)

        # Baseline: size of the image encoded as PNG, with no resizing.
        original_buffer = BytesIO()
        img.save(original_buffer, format="PNG")
        previous_image_size = check_byte_size_in_mb(original_buffer.getvalue())

        result_buffer = resize_image_and_return_image_in_bytes(
            img, max_image_size_mb=None
        )
        new_image_size = check_byte_size_in_mb(result_buffer.getvalue())
        self.assertEqual(previous_image_size, new_image_size)
class OpenAIConversationDontResizeImage(unittest.TestCase):
    """The "chatgpt" template's base64 conversion should not shrink a small image.

    The decoded base64 payload is expected to be exactly as large as the same
    image encoded as PNG.
    """

    def test(self):
        conv = get_conv_template("chatgpt")
        initial_size_mb = 0.2  # Initial image size
        img = generate_random_image(initial_size_mb)

        # Baseline: size of the image encoded as PNG, with no resizing.
        original_buffer = BytesIO()
        img.save(original_buffer, format="PNG")
        previous_image_size = check_byte_size_in_mb(original_buffer.getvalue())

        resized_img = conv.convert_image_to_base64(img)
        # Compare decoded bytes, not the (larger) base64 text.
        resized_img_bytes = base64.b64decode(resized_img)
        new_image_size = check_byte_size_in_mb(resized_img_bytes)
        self.assertEqual(previous_image_size, new_image_size)
class ClaudeConversationResizesCorrectly(unittest.TestCase):
    """The Claude template should shrink a large image below its size limits.

    Three checks are made: the decoded image is smaller than the original PNG,
    it fits ``conv.max_image_size_mb``, and the base64 text itself stays within
    a fixed upper bound.
    """

    def test(self):
        conv = get_conv_template("claude-3-haiku-20240307")
        initial_size_mb = 5  # Large enough that the template is expected to shrink it
        img = generate_random_image(initial_size_mb)

        # Baseline: size of the image encoded as PNG, with no resizing.
        original_buffer = BytesIO()
        img.save(original_buffer, format="PNG")
        previous_image_size = check_byte_size_in_mb(original_buffer.getvalue())

        resized_img = conv.convert_image_to_base64(img)
        # Measure both the base64 text and the decoded image bytes.
        new_base64_image_size = check_byte_size_in_mb(resized_img)
        new_image_bytes_size = check_byte_size_in_mb(base64.b64decode(resized_img))

        self.assertLess(new_image_bytes_size, previous_image_size)
        self.assertLessEqual(new_image_bytes_size, conv.max_image_size_mb)
        # Upper bound on the base64 payload itself.
        # NOTE(review): presumably the provider's 5 MB per-image limit — confirm.
        max_base64_size_mb = 5
        self.assertLessEqual(new_base64_image_size, max_base64_size_mb)