Skip to content
This repository has been archived by the owner on Oct 26, 2024. It is now read-only.

Commit

Permalink
perf(YouTube): Reduce memory requirement for prefix tree searching (#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
LisoUseInAIKyrios authored Oct 17, 2023
1 parent bd307e4 commit f5add51
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,15 @@ private static void findAsciiStrings(StringBuilder builder, byte[] buffer) {

static {
for (Filter filter : filters) {
filterGroupLists(pathSearchTree, filter, filter.pathFilterGroupList);
filterGroupLists(identifierSearchTree, filter, filter.identifierFilterGroupList);
filterGroupLists(pathSearchTree, filter, filter.pathFilterGroupList);
}

LogHelper.printDebug(() -> "Using: "
+ pathSearchTree.numberOfPatterns() + " path filters"
+ " (" + pathSearchTree.getEstimatedMemorySize() + " KB), "
+ identifierSearchTree.numberOfPatterns() + " identifier filters"
+ " (" + identifierSearchTree.getEstimatedMemorySize() + " KB)");
+ " (" + identifierSearchTree.getEstimatedMemorySize() + " KB), "
+ pathSearchTree.numberOfPatterns() + " path filters"
+ " (" + pathSearchTree.getEstimatedMemorySize() + " KB)");
}

private static <T> void filterGroupLists(TrieSearch<T> pathSearchTree,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,17 @@
public final class ByteTrieSearch extends TrieSearch<byte[]> {

private static final class ByteTrieNode extends TrieNode<byte[]> {
TrieNode<byte[]> createNode() {
return new ByteTrieNode();
ByteTrieNode() {
super();
}
ByteTrieNode(char nodeCharacterValue) {
super(nodeCharacterValue);
}
@Override
TrieNode<byte[]> createNode(char nodeCharacterValue) {
return new ByteTrieNode(nodeCharacterValue);
}
@Override
char getCharValue(byte[] text, int index) {
return (char) text[index];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@
public final class StringTrieSearch extends TrieSearch<String> {

private static final class StringTrieNode extends TrieNode<String> {
TrieNode<String> createNode() {
return new StringTrieNode();
StringTrieNode() {
super();
}
StringTrieNode(char nodeCharacterValue) {
super(nodeCharacterValue);
}
@Override
TrieNode<String> createNode(char nodeValue) {
return new StringTrieNode(nodeValue);
}
@Override
char getCharValue(String text, int index) {
return text.charAt(index);
}
Expand Down
116 changes: 101 additions & 15 deletions app/src/main/java/app/revanced/integrations/utils/TrieSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,31 @@ boolean matches(TrieNode<T> enclosingNode, // Used only for the get character me
}

static abstract class TrieNode<T> {
/**
* Dummy value used for root node. Value can be anything as it's never referenced.
*/
private static final char ROOT_NODE_CHARACTER_VALUE = 0; // ASCII null character.

// Support only ASCII letters/numbers/symbols and filter out all control characters.
private static final char MIN_VALID_CHAR = 32; // Space character.
private static final char MAX_VALID_CHAR = 126; // 127 = delete character.
private static final int NUMBER_OF_CHILDREN = MAX_VALID_CHAR - MIN_VALID_CHAR + 1;

/**
* How much to expand the children array when resizing.
*/
private static final int CHILDREN_ARRAY_INCREASE_SIZE_INCREMENT = 2;
private static final int CHILDREN_ARRAY_MAX_SIZE = MAX_VALID_CHAR - MIN_VALID_CHAR + 1;

/**
 * Checks whether a character falls outside the supported character range.
 *
 * @param character Character to check.
 * @return true if the character is below {@link #MIN_VALID_CHAR}
 *         or above {@link #MAX_VALID_CHAR}.
 */
private static boolean isInvalidRange(char character) {
    final boolean withinRange = MIN_VALID_CHAR <= character && character <= MAX_VALID_CHAR;
    return !withinRange;
}

/**
* Character this node represents.
* This field is ignored for the root node (which does not represent any character).
*/
private final char nodeValue;

/**
* A compressed graph path that represents the remaining pattern characters of a single child node.
*
Expand All @@ -91,6 +107,24 @@ private static boolean isInvalidRange(char character) {

/**
* All child nodes. Only present if no compressed leaf exists.
*
* Array is dynamically increased in size as needed,
* and uses perfect hashing for the elements it contains.
*
* So if the array contains a given character,
* the character will always map to the node with index: (character % arraySize).
*
* Characters not contained in the array can hash to the same index as a character it does contain,
* so a lookup must also compare the node's character value.
*
* Alternatively this array could be a sorted and densely packed array,
* and lookup is done using binary search.
* That would save a small amount of memory because there's no null children entries,
* but would give a worst case search of O(n log(m)) where n is the number of
* characters in the searched text and m is the maximum size of the sorted character arrays.
* Using a hash table array always gives O(n) search time.
* The memory usage here is very small (all Litho filters use ~10KB of memory),
* so the more performant hash implementation is chosen.
*/
@Nullable
private TrieNode<T>[] children;
Expand All @@ -101,6 +135,13 @@ private static boolean isInvalidRange(char character) {
@Nullable
private List<TriePatternMatchedCallback<T>> endOfPatternCallback;

/**
 * Creates a node holding the placeholder root character value
 * ({@link #ROOT_NODE_CHARACTER_VALUE}).
 */
TrieNode() {
    this(ROOT_NODE_CHARACTER_VALUE);
}
/**
 * Creates a node that stores the given character as its value.
 *
 * @param nodeCharacterValue Character this node represents.
 */
TrieNode(char nodeCharacterValue) {
    nodeValue = nodeCharacterValue;
}

/**
* @param pattern Pattern to add.
* @param patternLength Length of the pattern.
Expand All @@ -121,7 +162,7 @@ private void addPattern(@NonNull T pattern, int patternLength, int patternIndex,
// Recursively call back into this method and push the existing leaf down 1 level.
if (children != null) throw new IllegalStateException();
//noinspection unchecked
children = new TrieNode[NUMBER_OF_CHILDREN];
children = new TrieNode[1];
TrieCompressedPath<T> temp = leaf;
leaf = null;
addPattern(temp.pattern, temp.patternLength, temp.patternStartIndex, temp.callback);
Expand All @@ -130,19 +171,65 @@ private void addPattern(@NonNull T pattern, int patternLength, int patternIndex,
leaf = new TrieCompressedPath<>(pattern, patternLength, patternIndex, callback);
return;
}
char character = getCharValue(pattern, patternIndex);
final char character = getCharValue(pattern, patternIndex);
if (isInvalidRange(character)) {
throw new IllegalArgumentException("invalid character at index " + patternIndex + ": " + pattern);
}
character -= MIN_VALID_CHAR; // Adjust to the array range.
TrieNode<T> child = children[character];
final int arrayIndex = hashIndexForTableSize(children.length, character);
TrieNode<T> child = children[arrayIndex];
if (child == null) {
child = createNode();
children[character] = child;
child = createNode(character);
children[arrayIndex] = child;
} else if (child.nodeValue != character) {
// Hash collision. Resize the table until perfect hashing is found.
child = createNode(character);
expandChildArray(child);
}
child.addPattern(pattern, patternLength, patternIndex + 1, callback);
}

/**
* Resizes the children table until all nodes hash to exactly one array index.
* Worst case, this will resize the array to {@link #CHILDREN_ARRAY_MAX_SIZE} elements.
*/
private void expandChildArray(TrieNode<T> child) {
    // The children field is only reassigned once a collision free table has been built.
    final TrieNode<T>[] currentChildren = Objects.requireNonNull(children);
    int candidateSize = currentChildren.length + CHILDREN_ARRAY_INCREASE_SIZE_INCREMENT;

    for (; ; candidateSize += CHILDREN_ARRAY_INCREASE_SIZE_INCREMENT) {
        //noinspection unchecked
        TrieNode<T>[] candidate = new TrieNode[candidateSize];
        // The candidate table is still empty, so inserting the new child cannot collide.
        addNodeToArray(candidate, child);

        boolean everyChildPlaced = true;
        for (TrieNode<T> existingChild : currentChildren) {
            if (existingChild != null && !addNodeToArray(candidate, existingChild)) {
                // Two nodes hash to the same index at this size. Try the next larger size.
                everyChildPlaced = false;
                break;
            }
        }

        if (everyChildPlaced) {
            children = candidate;
            return;
        }
        // Still colliding after growing past the maximum table size.
        if (candidateSize > CHILDREN_ARRAY_MAX_SIZE) {
            throw new IllegalStateException();
        }
    }
}

/**
 * Places a node into the table at the index its character hashes to,
 * provided that index is not already occupied.
 *
 * @param array      Table to insert into.
 * @param childToAdd Node to insert.
 * @return true if the node was stored, false if the index was already taken (a collision).
 */
private static <T> boolean addNodeToArray(TrieNode<T>[] array, TrieNode<T> childToAdd) {
    final int slot = hashIndexForTableSize(array.length, childToAdd.nodeValue);
    final boolean slotIsFree = array[slot] == null;
    if (slotIsFree) {
        array[slot] = childToAdd;
    }
    return slotIsFree;
}

/**
 * Computes the table index for a character: its offset from {@link #MIN_VALID_CHAR},
 * modulo the table size.
 *
 * @param arraySize Length of the children table.
 * @param nodeValue Character to hash. The visible callers reject characters outside the
 *                  valid range (or reuse the value of an existing node) before calling.
 * @return Index into a table of {@code arraySize} elements.
 */
private static int hashIndexForTableSize(int arraySize, char nodeValue) {
    final int offsetFromMinimum = nodeValue - MIN_VALID_CHAR;
    return offsetFromMinimum % arraySize;
}

/**
* @param searchText Text to search for patterns in.
* @param searchTextLength Length of the search text.
Expand Down Expand Up @@ -170,18 +257,17 @@ private boolean matches(T searchText, int searchTextLength, int searchTextIndex,
if (children == null) {
return false; // Reached a graph end point and there's no further patterns to search.
}

if (searchTextIndex == searchTextLength) {
return false; // Reached end of the search text and found no matches.
}

char character = getCharValue(searchText, searchTextIndex);
final char character = getCharValue(searchText, searchTextIndex);
if (isInvalidRange(character)) {
return false; // Not an ASCII letter/number/symbol.
}
character -= MIN_VALID_CHAR; // Adjust to the array range.
TrieNode<T> child = children[character];
if (child == null) {
final int arrayIndex = hashIndexForTableSize(children.length, character);
TrieNode<T> child = children[arrayIndex];
if (child == null || child.nodeValue != character) {
return false;
}
return child.matches(searchText, searchTextLength, searchTextIndex + 1,
Expand All @@ -194,15 +280,15 @@ private boolean matches(T searchText, int searchTextLength, int searchTextIndex,
* @return Estimated number of memory pointers used, starting from this node and including all children.
*/
private int estimatedNumberOfPointersUsed() {
int numberOfPointers = 3; // Number of fields in this class.
int numberOfPointers = 4; // Number of fields in this class.
if (leaf != null) {
numberOfPointers += 4; // Number of fields in leaf node.
}
if (endOfPatternCallback != null) {
numberOfPointers += endOfPatternCallback.size();
}
if (children != null) {
numberOfPointers += NUMBER_OF_CHILDREN;
numberOfPointers += children.length;
for (TrieNode<T> child : children) {
if (child != null) {
numberOfPointers += child.estimatedNumberOfPointersUsed();
Expand All @@ -212,7 +298,7 @@ private int estimatedNumberOfPointersUsed() {
return numberOfPointers;
}

abstract TrieNode<T> createNode();
abstract TrieNode<T> createNode(char nodeValue);
abstract char getCharValue(T text, int index);
}

Expand Down

0 comments on commit f5add51

Please sign in to comment.