Back to Repositories

Validating ThreadLocal Memory Management in HikariCP

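This test suite checks HikariCP’s ConcurrentBag for ThreadLocal memory leaks in a simulated Apache Tomcat web-application environment. It verifies that, after the bag has been used and closed, no thread retains a ThreadLocal entry referencing classes loaded by the web application’s class loader. It also checks that the leak detector does report entries when the classes come from the test’s own class loader.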

Test Coverage Overview

The test suite comprehensively covers memory leak detection in ThreadLocal usage within HikariCP’s ConcurrentBag.

Key areas tested include:
  • ThreadLocal resource cleanup in web application contexts
  • ConcurrentBag operations (borrow, requite, remove)
  • ClassLoader leak detection
  • Enumeration of all live threads and inspection of their ThreadLocal maps

Implementation Analysis

The tests use JUnit together with a custom ClassLoader (FauxWebClassLoader) to simulate a web application. They use reflection to read each live thread’s internal ThreadLocal maps and check whether any key or value was loaded by the simulated web application’s class loader.

Notable patterns include:
  • Custom ClassLoader injection
  • Reflection-based thread inspection
  • Validation of ConcurrentBag add, borrow, requite, and remove operations

Technical Details

Testing infrastructure includes:
  • JUnit test framework
  • Custom FauxWebClassLoader implementation
  • Reflection API for thread inspection
  • SLF4J for logging
  • CompletableFuture to supply the bag’s listener with an already-completed future

Best Practices Demonstrated

The test suite exemplifies high-quality testing practices through thorough memory leak detection and thread management validation.

Notable practices include:
  • Inspection of ThreadLocal maps across all live threads
  • Proper resource cleanup checks
  • Robust error handling and logging
  • Systematic ClassLoader leak detection

brettwooldridge/hikaricp

src/test/java/com/zaxxer/hikari/util/TomcatConcurrentBagLeakTest.java

            
/*
 * Copyright (C) 2017 Brett Wooldridge
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.zaxxer.hikari.util;

import com.zaxxer.hikari.pool.TestElf.FauxWebClassLoader;
import com.zaxxer.hikari.util.ConcurrentBag.IConcurrentBagEntry;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runners.MethodSorters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.ref.Reference;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.ConcurrentModificationException;
import java.util.Iterator;
import java.util.concurrent.CompletableFuture;

import static com.zaxxer.hikari.pool.TestElf.isJava11;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assume.assumeTrue;

/**
 * Checks {@link ConcurrentBag} for ThreadLocal leaks in a simulated web-application setting.
 * <p>
 * {@link FauxWebContext} drives a bag through add/borrow/requite/remove, closes it, and then scans
 * the ThreadLocal maps of every live thread for entries whose key or value was loaded by the
 * context's class loader. The scanning logic is copied from Apache Tomcat.
 *
 * @author Brett Wooldridge
 */
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class TomcatConcurrentBagLeakTest
{
   /**
    * Loads {@link FauxWebContext} through a {@link FauxWebClassLoader}, which plays the role of a
    * web application. After the bag has been used and closed, no ThreadLocal may still reference an
    * object loaded by that class loader.
    */
   @Test
   public void testConcurrentBagForLeaks() throws Exception
   {
      assumeTrue(!isJava11());

      Exception ex = runFauxWebContext(new FauxWebClassLoader());
      assertNull(ex);
   }

   /**
    * Loads {@link FauxWebContext} through this test's own class loader and expects the detector to
    * record a finding. Presumably other classes on the test class path hold ThreadLocals that are
    * then attributed to the "web application"; this guards against a detector that silently finds
    * nothing.
    */
   @Test
   public void testConcurrentBagForLeaks2() throws Exception
   {
      assumeTrue(!isJava11());

      Exception ex = runFauxWebContext(this.getClass().getClassLoader());
      assertNotNull(ex);
   }

   /**
    * Loads {@link FauxWebContext} via {@code cl}, runs its {@code createConcurrentBag()} method and
    * returns whatever failure it recorded.
    *
    * @param cl class loader used to define/load {@code FauxWebContext}
    * @return the context's {@code failureException}; {@code null} when nothing was recorded
    * @throws Exception if reflection fails or {@code createConcurrentBag()} throws
    */
   private Exception runFauxWebContext(ClassLoader cl) throws Exception
   {
      // Reflection is used so the instance is created from the Class obtained through `cl`
      // rather than the FauxWebContext class statically linked into this test.
      Class<?> clazz = cl.loadClass(this.getClass().getName() + "$FauxWebContext");
      Object fauxWebContext = clazz.getDeclaredConstructor().newInstance();

      Method createConcurrentBag = clazz.getDeclaredMethod("createConcurrentBag");
      createConcurrentBag.invoke(fauxWebContext);

      Field failureException = clazz.getDeclaredField("failureException");
      return (Exception) failureException.get(fauxWebContext);
   }

   /**
    * Minimal bag entry used to drive {@link ConcurrentBag} in this test.
    * <p>
    * {@link #compareAndSet(int, int)} does not compare: it ignores {@code expectState}, always
    * stores {@code newState} and reports success.
    */
   public static class PoolEntry implements IConcurrentBagEntry
   {
      private int state;

      @Override
      public boolean compareAndSet(int expectState, int newState)
      {
         this.state = newState;
         return true;
      }

      @Override
      public void setState(int newState)
      {
         this.state = newState;
      }

      @Override
      public int getState()
      {
         return state;
      }
   }

   /**
    * Stand-in for a web application: exercises a {@link ConcurrentBag} and then scans all threads
    * for ThreadLocal entries that reference classes loaded by this class's class loader.
    */
   public static class FauxWebContext
   {
      private static final Logger log = LoggerFactory.getLogger(FauxWebContext.class);

      /**
       * Last failure recorded by the leak scan, or {@code null} if none. Read reflectively by the
       * enclosing test.
       */
      @SuppressWarnings("WeakerAccess")
      public Exception failureException;

      /**
       * Adds one entry to a new bag, borrows and requites it, borrows it again and removes it,
       * closes the bag, and finally scans all threads for leaked ThreadLocals.
       *
       * @throws InterruptedException if interrupted while borrowing
       */
      @SuppressWarnings({"ResultOfMethodCallIgnored"})
      public void createConcurrentBag() throws InterruptedException
      {
         try (ConcurrentBag<PoolEntry> bag = new ConcurrentBag<>(x -> CompletableFuture.completedFuture(Boolean.TRUE))) {

            PoolEntry entry = new PoolEntry();
            bag.add(entry);

            PoolEntry borrowed = bag.borrow(100, MILLISECONDS);
            bag.requite(borrowed);

            PoolEntry removed = bag.borrow(100, MILLISECONDS);
            bag.remove(removed);
         }

         checkThreadLocalsForLeaks();
      }

      /**
       * Inspects both the {@code threadLocals} and {@code inheritableThreadLocals} maps of every
       * live thread. Any reflection failure is logged and recorded in {@link #failureException},
       * with the original throwable preserved as the cause.
       */
      private void checkThreadLocalsForLeaks()
      {
         Thread[] threads = getThreads();

         try {
            // Make the fields in the Thread class that store ThreadLocals accessible
            Field threadLocalsField = Thread.class.getDeclaredField("threadLocals");
            threadLocalsField.setAccessible(true);
            Field inheritableThreadLocalsField = Thread.class.getDeclaredField("inheritableThreadLocals");
            inheritableThreadLocalsField.setAccessible(true);
            // Make the underlying array of ThreadLocal.ThreadLocalMap.Entry objects accessible
            Class<?> tlmClass = Class.forName("java.lang.ThreadLocal$ThreadLocalMap");
            Field tableField = tlmClass.getDeclaredField("table");
            tableField.setAccessible(true);
            Method expungeStaleEntriesMethod = tlmClass.getDeclaredMethod("expungeStaleEntries");
            expungeStaleEntriesMethod.setAccessible(true);

            for (Thread thread : threads) {
               Object threadLocalMap;
               // getThreads() may return an array with unused (null) trailing slots
               if (thread != null) {

                  // Clear the first map
                  threadLocalMap = threadLocalsField.get(thread);
                  if (null != threadLocalMap) {
                     expungeStaleEntriesMethod.invoke(threadLocalMap);
                     checkThreadLocalMapForLeaks(threadLocalMap, tableField);
                  }

                  // Clear the second map
                  threadLocalMap = inheritableThreadLocalsField.get(thread);
                  if (null != threadLocalMap) {
                     expungeStaleEntriesMethod.invoke(threadLocalMap);
                     checkThreadLocalMapForLeaks(threadLocalMap, tableField);
                  }
               }
            }
         }
         catch (Throwable t) {
            log.warn("Failed to check for ThreadLocal references for web application [{}]", getContextName(), t);
            // Keep the cause so a failing assertion shows why the scan could not complete
            failureException = new Exception("Failed to check for ThreadLocal references for web application ["
               + getContextName() + "]", t);
         }
      }

      /** Name used to identify this "web application" in log and failure messages. */
      private String getContextName()
      {
         return this.getClass().getName();
      }

      // THE FOLLOWING CODE COPIED FROM APACHE TOMCAT (2017/01/08)

      /**
       * Analyzes the given thread local map object. Also pass in the field that
       * points to the internal table to save re-calculating it on every
       * call to this method.
       * <p>
       * Unlike Tomcat, every entry whose key or value was loaded by this context's class loader is
       * recorded in {@link #failureException}, including the cases that are only logged at debug
       * level.
       *
       * @param map                a {@code ThreadLocal.ThreadLocalMap} instance (may be null)
       * @param internalTableField the accessible {@code table} field of {@code ThreadLocalMap}
       * @throws IllegalAccessException if a reflected field cannot be read
       * @throws NoSuchFieldException   if a map entry has no {@code value} field
       */
      private void checkThreadLocalMapForLeaks(Object map, Field internalTableField) throws IllegalAccessException, NoSuchFieldException
      {
         if (map != null) {
            Object[] table = (Object[]) internalTableField.get(map);
            if (table != null) {
               for (Object obj : table) {
                  if (obj != null) {
                     boolean keyLoadedByWebapp = false;
                     boolean valueLoadedByWebapp = false;
                     // Check the key (each entry is a Reference to its ThreadLocal)
                     Object key = ((Reference<?>) obj).get();
                     if (this.equals(key) || loadedByThisOrChild(key)) {
                        keyLoadedByWebapp = true;
                     }
                     // Check the value
                     Field valueField = obj.getClass().getDeclaredField("value");
                     valueField.setAccessible(true);
                     Object value = valueField.get(obj);
                     if (this.equals(value) || loadedByThisOrChild(value)) {
                        valueLoadedByWebapp = true;
                     }
                     if (keyLoadedByWebapp || valueLoadedByWebapp) {
                        // args: context name, key type, key string, value type, value string
                        Object[] args = new Object[5];
                        args[0] = getContextName();
                        if (key != null) {
                           args[1] = getPrettyClassName(key.getClass());
                           try {
                              args[2] = key.toString();
                           } catch (Exception e) {
                              log.warn("Unable to determine string representation of key of type [{}]", args[1], e);
                              args[2] = "Unknown";
                           }
                        }
                        if (value != null) {
                           args[3] = getPrettyClassName(value.getClass());
                           try {
                              args[4] = value.toString();
                           } catch (Exception e) {
                              log.warn("Unable to determine string representation of value of type [{}]", args[3], e);
                              args[4] = "Unknown";
                           }
                        }

                        if (valueLoadedByWebapp) {
                           log.error("The web application [{}] created a ThreadLocal with key of type [{}] " +
                              "(value [{}]) and a value of type [{}] (value [{}]) but failed to remove " +
                              "it when the web application was stopped. Threads are going to be renewed " +
                              "over time to try and avoid a probable memory leak.", args);
                        } else if (value == null) {
                           log.debug("The web application [{}] created a ThreadLocal with key of type [{}] " +
                              "(value [{}]). The ThreadLocal has been correctly set to null and the " +
                              "key will be removed by GC.", args);
                        } else {
                           log.debug("The web application [{}] created a ThreadLocal with key of type [{}] " +
                              "(value [{}]) and a value of type [{}] (value [{}]). Since keys are only " +
                              "weakly held by the ThreadLocal Map this is not a memory leak.", args);
                        }

                        // Every finding fails the test; describe it so the assertion output is useful
                        failureException = new Exception("Web application [" + args[0]
                           + "] left a ThreadLocal with key of type [" + args[1]
                           + "] and value of type [" + args[3] + "]");
                     }
                  }
               }
            }
         }
      }

      /**
       * @param o object to test, may be null
       * @return <code>true</code> if o has been loaded by the current classloader
       * or one of its descendants.
       */
      private boolean loadedByThisOrChild(Object o) {
         if (o == null) {
            return false;
         }

         Class<?> clazz;
         if (o instanceof Class) {
            clazz = (Class<?>) o;
         } else {
            clazz = o.getClass();
         }

         // Walk up from the object's loader looking for this context's loader
         ClassLoader thisLoader = this.getClass().getClassLoader();
         ClassLoader cl = clazz.getClassLoader();
         while (cl != null) {
            if (cl == thisLoader) {
               return true;
            }
            cl = cl.getParent();
         }

         // Collections are checked element by element
         if (o instanceof Collection<?>) {
            Iterator<?> iter = ((Collection<?>) o).iterator();
            try {
               while (iter.hasNext()) {
                  Object entry = iter.next();
                  if (loadedByThisOrChild(entry)) {
                     return true;
                  }
               }
            } catch (ConcurrentModificationException e) {
               log.warn("Failed to check for ThreadLocal references for web application [{}]", getContextName(), e);
            }
         }
         return false;
      }

      /**
       * Get the set of current threads as an array.
       * <p>
       * Starts at the root thread group when it can be reached. The returned array is sized with
       * headroom, so slots after the last enumerated thread are {@code null}.
       */
      private Thread[] getThreads()
      {
         // Get the current thread group
         ThreadGroup tg = Thread.currentThread().getThreadGroup();
         // Find the root thread group
         try {
            while (tg.getParent() != null) {
               tg = tg.getParent();
            }
         }
         catch (SecurityException se) {
            log.warn("Unable to obtain the parent for ThreadGroup [{}]. It will not be possible to check all threads for potential memory leaks", tg.getName(), se);
         }

         int threadCountGuess = tg.activeCount() + 50;
         Thread[] threads = new Thread[threadCountGuess];
         int threadCountActual = tg.enumerate(threads);
         // Make sure we don't miss any threads
         while (threadCountActual == threadCountGuess) {
            threadCountGuess *= 2;
            threads = new Thread[threadCountGuess];
            // Note tg.enumerate(Thread[]) silently ignores any threads that
            // can't fit into the array
            threadCountActual = tg.enumerate(threads);
         }

         return threads;
      }

      /** Canonical class name, falling back to the binary name when there is none. */
      private String getPrettyClassName(Class<?> clazz)
      {
         String name = clazz.getCanonicalName();
         if (name == null) {
            name = clazz.getName();
         }
         return name;
      }
   }
}