Validating ThreadLocal Memory Management in HikariCP
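This test suite validates ThreadLocal memory-leak detection for HikariCP’s ConcurrentBag implementation, using leak-checking logic adapted from Apache Tomcat. After a bag has been used (add, borrow, requite, remove) and closed, the test reflectively scans every thread’s ThreadLocal maps. It then checks whether any retained key or value was loaded by the simulated web-application class loader, which Tomcat reports as a probable memory leak.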
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
brettwooldridge/hikaricp
src/test/java/com/zaxxer/hikari/util/TomcatConcurrentBagLeakTest.java
/*
* Copyright (C) 2017 Brett Wooldridge
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.zaxxer.hikari.util;
import com.zaxxer.hikari.pool.TestElf.FauxWebClassLoader;
import com.zaxxer.hikari.util.ConcurrentBag.IConcurrentBagEntry;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runners.MethodSorters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.ref.Reference;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.ConcurrentModificationException;
import java.util.Iterator;
import java.util.concurrent.CompletableFuture;
import static com.zaxxer.hikari.pool.TestElf.isJava11;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assume.assumeTrue;
/**
 * Verifies that {@link ConcurrentBag} does not leave {@link ThreadLocal} entries behind whose
 * key or value was loaded by a (simulated) web-application class loader. The leak-detection
 * logic is adapted from Apache Tomcat's web-application class loader.
 *
 * <p>The test methods are run in name order ({@link MethodSorters#NAME_ASCENDING}).
 *
 * @author Brett Wooldridge
 */
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class TomcatConcurrentBagLeakTest
{
   /**
    * Loads {@link FauxWebContext} through a {@code FauxWebClassLoader} and expects the
    * ThreadLocal leak check to report nothing attributable to that loader.
    */
   @Test
   public void testConcurrentBagForLeaks() throws Exception
   {
      assumeTrue(!isJava11());

      Exception ex = runFauxWebContext(new FauxWebClassLoader());
      assertNull(ex);
   }

   /**
    * Loads {@link FauxWebContext} through this test's own class loader and expects the leak
    * check to flag at least one ThreadLocal entry (presumably because ConcurrentBag's classes
    * share that loader in this configuration -- confirm against the classpath setup).
    */
   @Test
   public void testConcurrentBagForLeaks2() throws Exception
   {
      assumeTrue(!isJava11());

      Exception ex = runFauxWebContext(this.getClass().getClassLoader());
      assertNotNull(ex);
      // A cause is only attached when the reflective inspection itself failed; in that case the
      // non-null result is not evidence that a leak was actually detected.
      assertNull("ThreadLocal inspection failed instead of detecting a leak", ex.getCause());
   }

   /**
    * Loads {@link FauxWebContext} through {@code cl}, reflectively invokes its
    * {@code createConcurrentBag()} method, and returns its {@code failureException} field.
    *
    * @param cl class loader used to load the nested {@code FauxWebContext} class
    * @return the exception recorded by the leak check, or {@code null} if none was recorded
    */
   private Exception runFauxWebContext(ClassLoader cl) throws Exception
   {
      Class<?> clazz = cl.loadClass(this.getClass().getName() + "$FauxWebContext");
      Object fauxWebContext = clazz.getDeclaredConstructor().newInstance();

      Method createConcurrentBag = clazz.getDeclaredMethod("createConcurrentBag");
      createConcurrentBag.invoke(fauxWebContext);

      Field failureException = clazz.getDeclaredField("failureException");
      return (Exception) failureException.get(fauxWebContext);
   }

   /**
    * Minimal bag entry used only by this test. {@code compareAndSet} ignores
    * {@code expectState} and always succeeds, so it is not a real atomic compare-and-set;
    * presumably adequate because the bag is only driven from one thread here.
    */
   public static class PoolEntry implements IConcurrentBagEntry
   {
      private int state;

      @Override
      public boolean compareAndSet(int expectState, int newState)
      {
         this.state = newState;
         return true;
      }

      @Override
      public void setState(int newState)
      {
         this.state = newState;
      }

      @Override
      public int getState()
      {
         return state;
      }
   }

   /**
    * Simulated web-application context. It is instantiated reflectively (possibly through a
    * separate class loader), exercises a {@link ConcurrentBag}, and records in
    * {@link #failureException} whether the ThreadLocal leak check reported a problem.
    */
   public static class FauxWebContext
   {
      private static final Logger log = LoggerFactory.getLogger(FauxWebContext.class);

      /**
       * {@code null} if the leak check found nothing; otherwise an exception describing a
       * retained ThreadLocal entry. If the check itself could not run, the exception's cause is
       * the underlying error. Read reflectively by the enclosing test, so it must stay public.
       */
      @SuppressWarnings("WeakerAccess")
      public Exception failureException;

      /**
       * Adds an entry to a new {@link ConcurrentBag}, borrows and requites it, borrows it again
       * and removes it, closes the bag, and then scans all threads' ThreadLocal maps.
       *
       * @throws InterruptedException propagated from the bag operations
       */
      @SuppressWarnings({"ResultOfMethodCallIgnored"})
      public void createConcurrentBag() throws InterruptedException
      {
         try (ConcurrentBag<PoolEntry> bag = new ConcurrentBag<>(x -> CompletableFuture.completedFuture(Boolean.TRUE))) {
            PoolEntry entry = new PoolEntry();
            bag.add(entry);

            PoolEntry borrowed = bag.borrow(100, MILLISECONDS);
            bag.requite(borrowed);

            PoolEntry removed = bag.borrow(100, MILLISECONDS);
            bag.remove(removed);
         }

         checkThreadLocalsForLeaks();
      }

      /**
       * Inspects the {@code threadLocals} and {@code inheritableThreadLocals} maps of every thread
       * in the root thread group. It first expunges stale entries (this mutates the JDK's maps via
       * reflection) and then checks the remaining entries with
       * {@link #checkThreadLocalMapForLeaks}. Any failure to perform the reflective inspection
       * (presumably e.g. the JDK denying access to its internals) is logged and recorded in
       * {@link #failureException} with the error as its cause.
       */
      private void checkThreadLocalsForLeaks()
      {
         Thread[] threads = getThreads();
         try {
            // Make the fields in the Thread class that store ThreadLocals accessible
            Field threadLocalsField = Thread.class.getDeclaredField("threadLocals");
            threadLocalsField.setAccessible(true);
            Field inheritableThreadLocalsField = Thread.class.getDeclaredField("inheritableThreadLocals");
            inheritableThreadLocalsField.setAccessible(true);

            // Make the underlying array of ThreadLocal.ThreadLocalMap.Entry objects accessible
            Class<?> tlmClass = Class.forName("java.lang.ThreadLocal$ThreadLocalMap");
            Field tableField = tlmClass.getDeclaredField("table");
            tableField.setAccessible(true);
            Method expungeStaleEntriesMethod = tlmClass.getDeclaredMethod("expungeStaleEntries");
            expungeStaleEntriesMethod.setAccessible(true);

            for (Thread thread : threads) {
               // getThreads() over-allocates its array, so trailing slots are null
               if (thread == null) {
                  continue;
               }

               // Clear and check the first map
               Object threadLocalMap = threadLocalsField.get(thread);
               if (null != threadLocalMap) {
                  expungeStaleEntriesMethod.invoke(threadLocalMap);
                  checkThreadLocalMapForLeaks(threadLocalMap, tableField);
               }

               // Clear and check the second map
               threadLocalMap = inheritableThreadLocalsField.get(thread);
               if (null != threadLocalMap) {
                  expungeStaleEntriesMethod.invoke(threadLocalMap);
                  checkThreadLocalMapForLeaks(threadLocalMap, tableField);
               }
            }
         }
         catch (Throwable t) {
            log.warn("Failed to check for ThreadLocal references for web application [{}]", getContextName(), t);
            // Keep the cause so the test can distinguish a broken check from a detected leak
            failureException = new Exception("Failed to check for ThreadLocal references", t);
         }
      }

      /** Name used to identify this context in log messages: the runtime class name. */
      private String getContextName()
      {
         return this.getClass().getName();
      }

      // THE FOLLOWING CODE COPIED (AND ADAPTED) FROM APACHE TOMCAT (2017/01/08)

      /**
       * Analyzes the given thread local map object. The caller also passes the field that points
       * to the internal table, so it is not re-calculated on every call to this method. Every
       * entry whose key or value was loaded by this context's class loader (or a descendant) is
       * logged and recorded in {@link #failureException}.
       *
       * @param map                a {@code ThreadLocal.ThreadLocalMap} instance, may be null
       * @param internalTableField the accessible {@code table} field of {@code ThreadLocalMap}
       */
      private void checkThreadLocalMapForLeaks(Object map, Field internalTableField) throws IllegalAccessException, NoSuchFieldException
      {
         if (map == null) {
            return;
         }
         Object[] table = (Object[]) internalTableField.get(map);
         if (table == null) {
            return;
         }

         for (Object obj : table) {
            if (obj == null) {
               continue;
            }

            // Check the key: each table entry is a Reference whose referent is the ThreadLocal key
            Object key = ((Reference<?>) obj).get();
            boolean keyLoadedByWebapp = this.equals(key) || loadedByThisOrChild(key);

            // Check the value
            Field valueField = obj.getClass().getDeclaredField("value");
            valueField.setAccessible(true);
            Object value = valueField.get(obj);
            boolean valueLoadedByWebapp = this.equals(value) || loadedByThisOrChild(value);

            if (!keyLoadedByWebapp && !valueLoadedByWebapp) {
               continue;
            }

            // args: [context name, key type, key string, value type, value string]
            Object[] args = new Object[5];
            args[0] = getContextName();
            if (key != null) {
               args[1] = getPrettyClassName(key.getClass());
               try {
                  args[2] = key.toString();
               }
               catch (Exception e) {
                  log.warn("Unable to determine string representation of key of type [{}]", args[1], e);
                  args[2] = "Unknown";
               }
            }
            if (value != null) {
               args[3] = getPrettyClassName(value.getClass());
               try {
                  args[4] = value.toString();
               }
               catch (Exception e) {
                  log.warn("Unable to determine string representation of value of type [{}]", args[3], e);
                  args[4] = "Unknown";
               }
            }

            if (valueLoadedByWebapp) {
               log.error("The web application [{}] created a ThreadLocal with key of type [{}] " +
                            "(value [{}]) and a value of type [{}] (value [{}]) but failed to remove " +
                            "it when the web application was stopped. Threads are going to be renewed " +
                            "over time to try and avoid a probable memory leak.", args);
            }
            else if (value == null) {
               log.debug("The web application [{}] created a ThreadLocal with key of type [{}] " +
                            "(value [{}]). The ThreadLocal has been correctly set to null and the " +
                            "key will be removed by GC.", args);
            }
            else {
               log.debug("The web application [{}] created a ThreadLocal with key of type [{}] " +
                            "(value [{}]) and a value of type [{}] (value [{}]). Since keys are only " +
                            "weakly held by the ThreadLocal Map this is not a memory leak.", args);
            }

            // Regardless of the logging branch above, this test treats any such entry as a failure
            failureException = new Exception("ThreadLocal entry retained for web application [" + args[0]
                                                + "]: key type [" + args[1] + "], value type [" + args[3] + "]");
         }
      }

      /**
       * @param o object to test, may be null
       * @return <code>true</code> if o has been loaded by the current classloader
       * or one of its descendants (for a {@link Collection}, also if any element was).
       */
      private boolean loadedByThisOrChild(Object o)
      {
         if (o == null) {
            return false;
         }

         Class<?> clazz;
         if (o instanceof Class) {
            clazz = (Class<?>) o;
         }
         else {
            clazz = o.getClass();
         }

         ClassLoader cl = clazz.getClassLoader();
         while (cl != null) {
            if (cl == this.getClass().getClassLoader()) {
               return true;
            }
            cl = cl.getParent();
         }

         if (o instanceof Collection<?>) {
            Iterator<?> iter = ((Collection<?>) o).iterator();
            try {
               while (iter.hasNext()) {
                  Object entry = iter.next();
                  if (loadedByThisOrChild(entry)) {
                     return true;
                  }
               }
            }
            catch (ConcurrentModificationException e) {
               log.warn("Failed to check for ThreadLocal references for web application [{}]", getContextName(), e);
            }
         }
         return false;
      }

      /**
       * Get the set of current threads as an array. The array may be larger than the number of
       * threads, so trailing slots can be null.
       */
      private Thread[] getThreads()
      {
         // Get the current thread group
         ThreadGroup tg = Thread.currentThread().getThreadGroup();

         // Find the root thread group
         try {
            while (tg.getParent() != null) {
               tg = tg.getParent();
            }
         }
         catch (SecurityException se) {
            log.warn("Unable to obtain the parent for ThreadGroup [{}]. It will not be possible to check all threads for potential memory leaks", tg.getName(), se);
         }

         int threadCountGuess = tg.activeCount() + 50;
         Thread[] threads = new Thread[threadCountGuess];
         int threadCountActual = tg.enumerate(threads);

         // Make sure we don't miss any threads: tg.enumerate(Thread[]) silently ignores any
         // threads that can't fit into the array, so grow until there is spare room
         while (threadCountActual == threadCountGuess) {
            threadCountGuess *= 2;
            threads = new Thread[threadCountGuess];
            threadCountActual = tg.enumerate(threads);
         }

         return threads;
      }

      /** Returns the canonical name of {@code clazz}, falling back to its binary name. */
      private String getPrettyClassName(Class<?> clazz)
      {
         String name = clazz.getCanonicalName();
         if (name == null) {
            name = clazz.getName();
         }
         return name;
      }
   }
}