Testing Multi-Task Learning Components in HanLP

This test suite validates the multi-task learning (MTL) capabilities of HanLP, focusing on fine-grained tokenization, semantic dependency parsing, and safe concurrent inference. It ensures robust handling of Chinese text across varied inputs, tokenizer configurations, and edge cases.
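
The pattern under test is straightforward: load one multi-task pipeline and restrict each call to the tasks of interest. A minimal sketch (the printed output mirrors the suite's own assertion in test_mtl_single_sent):

import hanlp

# Load the small Chinese multi-task ELECTRA model covering tokenization,
# POS, NER, SRL, dependency, semantic dependency, and constituency parsing.
mtl = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH)

# Restrict inference to the fine-grained tokenization task.
doc = mtl('商品和服务', tasks='tok/fine')
print(doc['tok/fine'])  # ['商品', '和', '服务']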

Test Coverage Overview

The test suite provides comprehensive coverage of HanLP’s MTL functionality, including:

  • Single and multiple sentence processing
  • Empty string handling
  • Pre-tokenized text processing (sketched after this list)
  • Semantic dependency parsing
  • Unicode and emoji handling
  • Custom dictionary combinations
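
Pre-tokenized input, for example, is passed as nested token lists while the tok* tasks are skipped, so downstream tasks reuse the supplied segmentation. A condensed sketch of what test_skip_tok below verifies (mtl is the pipeline loaded above):

sents = [['商品和服务', '一个', '词'], ['研究', '生命']]
doc = mtl(sents, skip_tasks='tok*')  # the tokenizer never runs
assert doc['tok'] == sents           # the input tokens are echoed back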

Implementation Analysis

The tests are written with Python’s unittest framework, one focused method per behavior, with assertions comparing the pipeline’s Document output directly against expected token lists and parse structures. A single pretrained multi-task ELECTRA model is loaded once at module scope and shared by every test, which keeps the suite fast but means tests that mutate tokenizer configuration (custom dictionaries, span output) must restore it afterwards.
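
One non-obvious consequence of this design is that the single shared model must be safe to call from several threads at once. The suite checks this with multiprocessing.dummy, which despite its package name provides a thread pool; roughly:

from multiprocessing.dummy import Pool  # threads, not processes

def tokenize(mtl, text):
    return mtl(text, tasks='tok/fine')['tok/fine']

with Pool(8) as pool:
    results = pool.starmap(tokenize, [(mtl, '商品和服务')] * 8)
# Every thread should observe the same segmentation.
assert results == [['商品', '和', '服务']] * 8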

Technical Details

Key technical components include:

  • HanLP pretrained ELECTRA model integration
  • Thread-safety testing via a multiprocessing.dummy thread pool
  • Custom dictionary configuration (dict_force, dict_combine; see the sketch after this list)
  • Span output validation (output_spans character offsets)
  • Unicode character handling, including combining marks and emoticons
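
The dictionary and span mechanics listed above work as follows, condensed from the tests below (mtl is the loaded pipeline):

tok = mtl['tok/fine']

# dict_force overrides the model's segmentation outright; dict_combine only
# merges tokens the model already produced into longer units.
tok.dict_force = {'用户ID'}
tok.dict_combine = {('iPad', 'Pro')}

# With output_spans enabled, each token becomes a (text, begin, end) triple
# that maps back into the original string.
tok.config.output_spans = True
sent = '我先去看医生'
for text, begin, end in mtl(sent, tasks='tok/fine')['tok/fine']:
    assert text == sent[begin:end]

# Reset the shared state afterwards, as the suite itself does.
tok.dict_force = tok.dict_combine = None
tok.config.output_spans = False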

Best Practices Demonstrated

The test suite exemplifies several testing best practices:

  • Isolated test cases for specific functionality
  • Comprehensive edge case handling
  • Clear test method naming conventions
  • Restoration of shared tokenizer state after tests that mutate it (a hardened variant is sketched below)
  • Efficient use of assertions for validation
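
Note that the restoration in test_tok_offset relies on plain reassignment at the end of the method, which is skipped if an assertion fails first. A hypothetical hardening (not part of the original suite) would use unittest’s addCleanup so the shared tokenizer is always reset, assuming the mtl pipeline and imports from the file below:

class TestStateRestoration(unittest.TestCase):
    def test_offsets_with_cleanup(self):
        tok = mtl['tok/fine']
        # Registered cleanups run even when an assertion fails, so other
        # tests never observe the mutated configuration.
        self.addCleanup(setattr, tok.config, 'output_spans', False)
        self.addCleanup(setattr, tok, 'dict_combine', None)

        tok.config.output_spans = True
        sent = '我先去看医生'
        for text, begin, end in mtl(sent, tasks='tok/fine')['tok/fine']:
            self.assertEqual(text, sent[begin:end])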

hankcs/hanlp

tests/test_mtl.py

import hanlp
import unittest
from multiprocessing.dummy import Pool
from hanlp_common.document import Document

# Load the shared multi-task model once at module scope; devices=-1 keeps it on CPU.
mtl = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH, devices=-1)


def tokenize(mtl, text):
    # Helper used by the thread pool below: run only fine-grained tokenization.
    return mtl(text, tasks='tok/fine')['tok/fine']


class TestMultiTaskLearning(unittest.TestCase):
    def test_mtl_single_sent(self):
        doc: Document = mtl('商品和服务')
        self.assertSequenceEqual(doc['tok/fine'], ["商品", "和", "服务"])

    def test_mtl_multiple_sents(self):
        doc: Document = mtl(['商品和服务', '研究生命'])
        self.assertSequenceEqual(doc['tok/fine'], [
            ["商品", "和", "服务"],
            ["研究", "生命"]
        ])

    def test_mtl_empty_str(self):
        # Empty and whitespace-only inputs must be handled without raising.
        mtl('')
        mtl(' ')
        mtl([''])
        mtl([' '])
        mtl(['', ' '])
        mtl(['', ' ', 'good'])
        mtl([[]], skip_tasks='tok*')

    def test_skip_tok(self):
        pre_tokenized_sents = [
            ["商品和服务", '一个', '词'],
            ["研究", "生命"]
        ]
        doc: Document = mtl(pre_tokenized_sents, skip_tasks='tok*')
        self.assertSequenceEqual(doc['tok'], pre_tokenized_sents)

    def test_sdp_as_the_first_task(self):
        doc: Document = mtl(['人', '吃', '鱼'], tasks='sdp', skip_tasks='tok*')
        self.assertDictEqual(
            doc.to_dict(),
            {
                "sdp": [
                    [(2, "Agt")],
                    [(0, "Root")],
                    [(2, "Pat")]
                ],
                "tok": [
                    "人",
                    "吃",
                    "鱼"
                ]
            }
        )

    def test_threading(self):
        num_proc = 8
        # multiprocessing.dummy provides a thread pool, so all workers share
        # the same model instance loaded above.
        with Pool(num_proc) as pool:
            results = pool.starmap(tokenize, [(mtl, '商品和服务')] * num_proc)
            self.assertSequenceEqual(results, [['商品', '和', '服务']] * num_proc)

    def test_emoji(self):
        self.assertSequenceEqual(mtl('( ͡° ͜ʖ ͡ °)你好', tasks='tok/fine')['tok/fine'],
                                 ["(", " ͡", "°", " ͜", "ʖ", " ͡ ", "°", ")", "你", "好"])
        mtl['tok/fine'].dict_combine = {'( ͡° ͜ʖ ͡ °)'}
        self.assertSequenceEqual(mtl('( ͡° ͜ʖ ͡ °)你好', tasks='tok/fine')['tok/fine'],
                                 ["( ͡° ͜ʖ ͡ °)", "你", "好"])

    def test_unicode_removed_by_hf(self):
        # U+0361 is a combining mark that the underlying HuggingFace tokenizer
        # may drop during normalization; the pipeline must still return it.
        self.assertSequenceEqual(mtl('͡', tasks='tok/fine')['tok/fine'], ['͡'])

    def test_space(self):
        task = 'tok/fine'
        doc: Document = mtl('商品 和服务', tasks=task)
        self.assertSequenceEqual(doc[task], ["商品", "和", "服务"])
        mtl[task].dict_combine = {('iPad', 'Pro'), '2个空格'}
        self.assertSequenceEqual(mtl("如何评价iPad Pro ?iPad  Pro有2个空格", tasks=task)[task],
                                 ['如何', '评价', 'iPad Pro', '?', 'iPad  Pro', '有', '2个空格'])

    def test_transform(self):
        task = 'tok/fine'
        mtl[task].dict_force = {'用户ID'}
        self.assertSequenceEqual(mtl("我的用户ID跟你的用户id不同", tasks=task)[task],
                                 ['我', '的', '用户ID', '跟', '你', '的', '用户', 'id', '不同'])

    def test_tok_offset(self):
        task = 'tok/fine'
        tok = mtl[task]
        # Emit (text, begin, end) triples instead of plain token strings.
        tok.config.output_spans = True
        tok.dict_force = None
        tok.dict_combine = None
        sent = '我先去看医生'

        for t, b, e in mtl(sent, tasks=task)[task]:
            self.assertEqual(t, sent[b:e])

        tok.dict_combine = {'先去'}
        for t, b, e in mtl(sent, tasks=task)[task]:
            self.assertEqual(t, sent[b:e])

        # Restore the shared tokenizer's defaults for subsequent tests.
        tok.config.output_spans = False
        tok.dict_force = None
        tok.dict_combine = None


if __name__ == '__main__':
    unittest.main()