Skip to content

Commit

Permalink
Properly handle nested generics and multiple wildcard type args in Ja…
Browse files Browse the repository at this point in the history
…rInfer (#1114)

Our previous code would crash on nested generic types or when multiple
unbounded wildcard type arguments were passed consecutively.
  • Loading branch information
msridhar authored Dec 24, 2024
1 parent 728bf77 commit 17df87f
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.uber.nullaway.jarinfer;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.ibm.wala.cfg.ControlFlowGraph;
Expand Down Expand Up @@ -61,11 +62,13 @@
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.attribute.FileTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.jar.JarFile;
Expand Down Expand Up @@ -526,6 +529,13 @@ private static String getAstubxSignature(IMethod mtd) {
// get types that include generic type arguments
returnType = getSourceLevelQualifiedTypeName(genericSignature.getReturnType().toString());
TypeSignature[] argTypeSigs = genericSignature.getArguments();
Verify.verify(
argTypeSigs.length == numParams,
"Mismatch in number of parameters in generic signature: %s with %s vs %s with %s",
mtd.getSignature(),
numParams,
genericSignature,
argTypeSigs.length);
for (int i = 0; i < argTypeSigs.length; i++) {
argTypes[i] = getSourceLevelQualifiedTypeName(argTypeSigs[i].toString());
}
Expand Down Expand Up @@ -591,7 +601,7 @@ private static String getSourceLevelQualifiedTypeName(String typeName) {
int idx = typeName.indexOf("<");
String baseType = typeName.substring(0, idx);
// generic type args are separated by semicolons in signature stored in bytecodes
String[] genericTypeArgs = typeName.substring(idx + 1, typeName.length() - 2).split(";");
String[] genericTypeArgs = splitTypeArgs(typeName.substring(idx + 1, typeName.length() - 2));
for (int i = 0; i < genericTypeArgs.length; i++) {
genericTypeArgs[i] = getSourceLevelQualifiedTypeName(genericTypeArgs[i]);
}
Expand All @@ -602,6 +612,51 @@ private static String getSourceLevelQualifiedTypeName(String typeName) {
}
}

/**
* Splits out the top-level type arguments from a string representing all arguments (from the
* bytecode-level signature)
*
* @param allTypeArgs string representing all type arguments
* @return array of strings representing top-level type arguments
*/
private static String[] splitTypeArgs(String allTypeArgs) {
List<String> result = new ArrayList<>();
StringBuilder currentTypeArg = new StringBuilder();
// track angle bracket depth to handle nested generic types
int angleBracketDepth = 0;

for (int i = 0; i < allTypeArgs.length(); i++) {
char c = allTypeArgs.charAt(i);
if (c == '<') {
angleBracketDepth++;
currentTypeArg.append(c);
} else if (c == '>') {
angleBracketDepth--;
currentTypeArg.append(c);
} else if (c == '*' && angleBracketDepth == 0) {
// Wildcard (not followed by semicolon)
currentTypeArg.append(c);
result.add(currentTypeArg.toString());
currentTypeArg.setLength(0);
} else if (c == ';' && angleBracketDepth == 0) {
// Split on semicolon only if not nested within <>
result.add(currentTypeArg.toString());
currentTypeArg.setLength(0);
} else {
currentTypeArg.append(c);
}
}

// there should be no extra characters left
Verify.verify(
currentTypeArg.length() == 0,
"unexpected characters left in generic type args string %s: %s",
allTypeArgs,
currentTypeArg);

return result.toArray(new String[0]);
}

private static boolean isWildcard(String typeName) {
char firstChar = typeName.charAt(0);
return firstChar == '*' || firstChar == '+' || firstChar == '-';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,27 @@ public void testMethodWithGenericParameter() throws Exception {
"}");
}

@Test
public void nestedGeneric() throws Exception {
testTemplate(
"nestedGeneric",
"generic",
"TestGeneric",
ImmutableMap.of(
"generic.TestGeneric:java.lang.String getString(generic.TestGeneric.Generic<generic.TestGeneric.Generic<java.lang.String>>)",
Sets.newHashSet(0)),
"public class TestGeneric {",
" static class Generic<T> {",
" public String foo(T t) {",
" return \"hi\";",
" }",
" }",
" public String getString(Generic<Generic<String>> g) {",
" return g.foo(null);",
" }",
"}");
}

@Test
public void wildcards() throws Exception {
testTemplate(
Expand Down Expand Up @@ -508,6 +529,40 @@ public void wildcards() throws Exception {
"}");
}

@Test
public void multiArgWildcards() throws Exception {
testTemplate(
"multiArgWildcards",
"generic",
"TestGeneric",
ImmutableMap.of(
"generic.TestGeneric:void genericMultiWildcard(java.lang.String, generic.TestGeneric.Generic<?,?>)",
Sets.newHashSet(1)),
"public class TestGeneric {",
" public abstract static class Generic<T,U> {",
" public void doNothing() {}",
" }",
" public static void genericMultiWildcard(String s, Generic<?,?> g) { g.doNothing(); };",
"}");
}

@Test
public void nestedWildcard() throws Exception {
testTemplate(
"nestedWildcard",
"generic",
"TestGeneric",
ImmutableMap.of(
"generic.TestGeneric:void nestedWildcard(generic.TestGeneric.Generic<generic.TestGeneric.Generic<?>>)",
Sets.newHashSet(0)),
"public class TestGeneric {",
" public abstract static class Generic<T> {",
" public void doNothing() {}",
" }",
" public static void nestedWildcard(Generic<Generic<?>> g) { g.doNothing(); };",
"}");
}

@Test
public void toyJARAnnotatingClasses() throws Exception {
testAnnotationInJarTemplate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ public void genericsTest() {
" g.getString(null);",
" // BUG: Diagnostic contains: passing @Nullable parameter 'null'",
" Toys.genericParam(null);",
" // BUG: Diagnostic contains: passing @Nullable parameter 'null'",
" Toys.nestedGenericParam(null);",
" }",
"}")
.doTest();
Expand All @@ -114,9 +116,14 @@ public void wildcards() {
" // BUG: Diagnostic contains: passing @Nullable parameter 'null'",
" Toys.genericWildcard(null);",
" // BUG: Diagnostic contains: passing @Nullable parameter 'null'",
" Toys.nestedGenericWildcard(null);",
" // BUG: Diagnostic contains: passing @Nullable parameter 'null'",
" Toys.genericWildcardUpper(null);",
" // BUG: Diagnostic contains: passing @Nullable parameter 'null'",
" Toys.genericWildcardLower(null);",
" // BUG: Diagnostic contains: passing @Nullable parameter 'null'",
" Toys.doubleGenericWildcard(\"\", null);",
" Toys.doubleGenericWildcardNullOk(\"\", null);",
" }",
"}")
.doTest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,18 @@ public static void genericParam(Generic<String> g) {
g.getString("hello");
}

public static void nestedGenericParam(Generic<Generic<String>> g) {
g.getString(null);
}

public static void genericWildcard(Generic<?> g) {
g.doNothing();
}

public static void nestedGenericWildcard(Generic<Generic<?>> g) {
g.doNothing();
}

public static String genericWildcardUpper(Generic<? extends String> g) {
return g.getSomething();
}
Expand All @@ -63,6 +71,20 @@ public static void genericWildcardLower(Generic<? super String> g) {
g.getString("hello");
}

public abstract static class DoubleGeneric<T, U> {
public void doNothing() {}
}

public static void doubleGenericWildcard(String s, DoubleGeneric<?, ?> g) {
g.doNothing();
}

public static void doubleGenericWildcardNullOk(String s, DoubleGeneric<?, ?> g) {
if (g != null) {
g.doNothing();
}
}

public static void main(String arg[]) throws java.io.IOException {
String s = "test string...";
Foo f = new Foo("let's");
Expand Down

0 comments on commit 17df87f

Please sign in to comment.