Skip to content

Commit

Permalink
[Improve][Core] Improve base on plugin name of lookup strategy (apach…
Browse files Browse the repository at this point in the history
…e#7278)

* [bug][plugin-discovery] fix multi plugin discovery

* [bug][plugin-discovery] optimize code

---------

Co-authored-by: wangchao <[email protected]>
  • Loading branch information
corgy-w and corgy-w authored Jul 29, 2024
1 parent 3ccc6a8 commit 21c4f52
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -430,17 +432,16 @@ public boolean accept(File pathname) {
if (ArrayUtils.isEmpty(targetPluginFiles)) {
return Optional.empty();
}
if (targetPluginFiles.length > 1) {
throw new IllegalArgumentException(
"Found multiple plugin jar: "
+ Arrays.stream(targetPluginFiles)
.map(File::getPath)
.collect(Collectors.joining(","))
+ " for pluginIdentifier: "
+ pluginIdentifier);
}
try {
URL pluginJarPath = targetPluginFiles[0].toURI().toURL();
URL pluginJarPath;
if (targetPluginFiles.length == 1) {
pluginJarPath = targetPluginFiles[0].toURI().toURL();
} else {
pluginJarPath =
findMostSimlarPluginJarFile(targetPluginFiles, pluginJarPrefix)
.toURI()
.toURL();
}
log.info("Discovery plugin jar for: {} at: {}", pluginIdentifier, pluginJarPath);
return Optional.of(pluginJarPath);
} catch (MalformedURLException e) {
Expand All @@ -451,4 +452,105 @@ public boolean accept(File pathname) {
return Optional.empty();
}
}

private static File findMostSimlarPluginJarFile(
File[] targetPluginFiles, String pluginJarPrefix) {
String splitRegex = "\\-|\\_|\\.";
double maxSimlarity = -Integer.MAX_VALUE;
int mostSimlarPluginJarFileIndex = -1;
for (int i = 0; i < targetPluginFiles.length; i++) {
File file = targetPluginFiles[i];
String fileName = file.getName();
double similarity =
CosineSimilarityUtil.cosineSimilarity(pluginJarPrefix, fileName, splitRegex);
if (similarity > maxSimlarity) {
maxSimlarity = similarity;
mostSimlarPluginJarFileIndex = i;
}
}
return targetPluginFiles[mostSimlarPluginJarFileIndex];
}

static class CosineSimilarityUtil {
public static double cosineSimilarity(String textA, String textB, String splitRegrex) {
Set<String> words1 =
new HashSet<>(Arrays.asList(textA.toLowerCase().split(splitRegrex)));
Set<String> words2 =
new HashSet<>(Arrays.asList(textB.toLowerCase().split(splitRegrex)));
int[] termFrequency1 = calculateTermFrequencyVector(textA, words1, splitRegrex);
int[] termFrequency2 = calculateTermFrequencyVector(textB, words2, splitRegrex);
return calculateCosineSimilarity(termFrequency1, termFrequency2);
}

private static int[] calculateTermFrequencyVector(
String text, Set<String> words, String splitRegrex) {
int[] termFrequencyVector = new int[words.size()];
String[] textArray = text.toLowerCase().split(splitRegrex);
List<String> orderedWords = new ArrayList<String>();
words.clear();
for (String word : textArray) {
if (!words.contains(word)) {
orderedWords.add(word);
words.add(word);
}
}
for (String word : textArray) {
if (words.contains(word)) {
int index = 0;
for (String w : orderedWords) {
if (w.equals(word)) {
termFrequencyVector[index]++;
break;
}
index++;
}
}
}
return termFrequencyVector;
}

private static double calculateCosineSimilarity(int[] vectorA, int[] vectorB) {
double dotProduct = 0.0;
double magnitudeA = 0.0;
double magnitudeB = 0.0;
int vectorALength = vectorA.length;
int vectorBLength = vectorB.length;
if (vectorALength < vectorBLength) {
int[] vectorTemp = new int[vectorBLength];
for (int i = 0; i < vectorB.length; i++) {
if (i <= vectorALength - 1) {
vectorTemp[i] = vectorA[i];
} else {
vectorTemp[i] = 0;
}
}
vectorA = vectorTemp;
}
if (vectorALength > vectorBLength) {
int[] vectorTemp = new int[vectorALength];
for (int i = 0; i < vectorA.length; i++) {
if (i <= vectorBLength - 1) {
vectorTemp[i] = vectorB[i];
} else {
vectorTemp[i] = 0;
}
}
vectorB = vectorTemp;
}
for (int i = 0; i < vectorA.length; i++) {
dotProduct += vectorA[i] * vectorB[i];
magnitudeA += Math.pow(vectorA[i], 2);
magnitudeB += Math.pow(vectorB[i], 2);
}

magnitudeA = Math.sqrt(magnitudeA);
magnitudeB = Math.sqrt(magnitudeB);

if (magnitudeA == 0 || magnitudeB == 0) {
return 0.0; // Avoid dividing by 0
} else {
return dotProduct / (magnitudeA * magnitudeB);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@
import com.google.common.collect.Lists;

import java.io.IOException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

@DisabledOnOs(OS.WINDOWS)
class SeaTunnelSourcePluginDiscoveryTest {
Expand All @@ -47,7 +50,10 @@ class SeaTunnelSourcePluginDiscoveryTest {
private static final List<Path> pluginJars =
Lists.newArrayList(
Paths.get(seatunnelHome, "connectors", "connector-http-jira.jar"),
Paths.get(seatunnelHome, "connectors", "connector-http.jar"));
Paths.get(seatunnelHome, "connectors", "connector-http.jar"),
Paths.get(seatunnelHome, "connectors", "connector-kafka.jar"),
Paths.get(seatunnelHome, "connectors", "connector-kafka-alcs.jar"),
Paths.get(seatunnelHome, "connectors", "connector-kafka-blcs.jar"));

@BeforeEach
public void before() throws IOException {
Expand All @@ -67,12 +73,25 @@ void getPluginBaseClass() {
List<PluginIdentifier> pluginIdentifiers =
Lists.newArrayList(
PluginIdentifier.of("seatunnel", PluginType.SOURCE.getType(), "HttpJira"),
PluginIdentifier.of("seatunnel", PluginType.SOURCE.getType(), "HttpBase"));
PluginIdentifier.of("seatunnel", PluginType.SOURCE.getType(), "HttpBase"),
PluginIdentifier.of("seatunnel", PluginType.SOURCE.getType(), "Kafka"),
PluginIdentifier.of("seatunnel", PluginType.SINK.getType(), "Kafka-Blcs"));
SeaTunnelSourcePluginDiscovery seaTunnelSourcePluginDiscovery =
new SeaTunnelSourcePluginDiscovery();
Assertions.assertThrows(
IllegalArgumentException.class,
() -> seaTunnelSourcePluginDiscovery.getPluginJarPaths(pluginIdentifiers));
Assertions.assertIterableEquals(
Stream.of(
Paths.get(seatunnelHome, "connectors", "connector-http-jira.jar")
.toString(),
Paths.get(seatunnelHome, "connectors", "connector-http.jar")
.toString(),
Paths.get(seatunnelHome, "connectors", "connector-kafka.jar")
.toString(),
Paths.get(seatunnelHome, "connectors", "connector-kafka-blcs.jar")
.toString())
.collect(Collectors.toList()),
seaTunnelSourcePluginDiscovery.getPluginJarPaths(pluginIdentifiers).stream()
.map(URL::getPath)
.collect(Collectors.toList()));
}

@AfterEach
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,10 @@
seatunnel.source.HttpBase = connector-http
seatunnel.sink.HttpBase = connector-http
seatunnel.source.HttpJira = connector-http-jira
seatunnel.sink.HttpJira = connector-http-jira
seatunnel.sink.HttpJira = connector-http-jira
seatunnel.source.Kafka = connector-kafka
seatunnel.sink.Kafka = connector-kafka
seatunnel.source.Kafka-Alcs = connector-kafka-alcs
seatunnel.sink.Kafka-Alcs = connector-kafka-alcs
seatunnel.source.Kafka-Blcs = connector-kafka-blcs
seatunnel.sink.Kafka-Blcs = connector-kafka-blcs

0 comments on commit 21c4f52

Please sign in to comment.