Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improve][Zeta] Split the classloader of task group #7580

Merged
merged 3 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/backend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ jobs:
echo "engine-e2e=$true_or_false" >> $GITHUB_OUTPUT
echo "engine-e2e_files=$file_list" >> $GITHUB_OUTPUT

api_files=`python tools/update_modules_check/check_file_updates.py ua $workspace apache/dev origin/$current_branch "seatunnel-api/**" "seatunnel-common/**" "seatunnel-config/**" "seatunnel-connectors/**" "seatunnel-core/**" "seatunnel-e2e/seatunnel-e2e-common/**" "seatunnel-formats/**" "seatunnel-plugin-discovery/**" "seatunnel-transforms-v2/**" "seatunnel-translation/**" "seatunnel-e2e/seatunnel-transforms-v2-e2e/**" "seatunnel-connectors/**" "pom.xml" "**/workflows/**" "tools/**" "seatunnel-dist/**"`
api_files=`python tools/update_modules_check/check_file_updates.py ua $workspace apache/dev origin/$current_branch "seatunnel-api/**" "seatunnel-common/**" "seatunnel-config/**" "seatunnel-engine/**" "seatunnel-core/**" "seatunnel-e2e/seatunnel-e2e-common/**" "seatunnel-formats/**" "seatunnel-plugin-discovery/**" "seatunnel-transforms-v2/**" "seatunnel-translation/**" "seatunnel-e2e/seatunnel-transforms-v2-e2e/**" "pom.xml" "**/workflows/**" "tools/**" "seatunnel-dist/**"`
true_or_false=${api_files%%$'\n'*}
file_list=${api_files#*$'\n'}
if [[ $repository_owner == 'apache' ]];then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ public class HiveSourceConfig implements Serializable {

private static final long serialVersionUID = 1L;

private final Table table;
private final CatalogTable catalogTable;
private final FileFormat fileFormat;
private final ReadStrategy readStrategy;
Expand All @@ -81,7 +80,7 @@ public HiveSourceConfig(ReadonlyConfig readonlyConfig) {
readonlyConfig
.getOptional(HdfsSourceConfigOptions.READ_PARTITIONS)
.ifPresent(this::validatePartitions);
this.table = HiveTableUtils.getTableInfo(readonlyConfig);
Table table = HiveTableUtils.getTableInfo(readonlyConfig);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid conflicts in the logical dag phase of Hive

this.hadoopConf = parseHiveHadoopConfig(readonlyConfig, table);
this.fileFormat = HiveTableUtils.parseFileFormat(table);
this.readStrategy = parseReadStrategy(table, readonlyConfig, fileFormat, hadoopConf);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ private void startMaster() {

private void startWorker() {
taskExecutionService =
new TaskExecutionService(
classLoaderService, nodeEngine, nodeEngine.getProperties(), eventService);
new TaskExecutionService(classLoaderService, nodeEngine, eventService);
nodeEngine.getMetricsRegistry().registerDynamicMetricsProvider(taskExecutionService);
taskExecutionService.start();
getSlotService();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.apache.seatunnel.engine.server.execution.TaskGroup;
import org.apache.seatunnel.engine.server.execution.TaskGroupContext;
import org.apache.seatunnel.engine.server.execution.TaskGroupLocation;
import org.apache.seatunnel.engine.server.execution.TaskGroupUtils;
import org.apache.seatunnel.engine.server.execution.TaskLocation;
import org.apache.seatunnel.engine.server.execution.TaskTracker;
import org.apache.seatunnel.engine.server.metrics.SeaTunnelMetricsContext;
Expand All @@ -65,13 +66,12 @@
import com.hazelcast.map.IMap;
import com.hazelcast.spi.impl.NodeEngineImpl;
import com.hazelcast.spi.impl.operationservice.impl.InvocationFuture;
import com.hazelcast.spi.properties.HazelcastProperties;
import lombok.Getter;
import lombok.NonNull;
import lombok.SneakyThrows;

import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -149,7 +149,6 @@ public class TaskExecutionService implements DynamicMetricsProvider {
public TaskExecutionService(
ClassLoaderService classLoaderService,
NodeEngineImpl nodeEngine,
HazelcastProperties properties,
EventService eventService) {
seaTunnelConfig = ConfigProvider.locateAndGetSeaTunnelConfig();
this.hzInstanceName = nodeEngine.getHazelcastInstance().getName();
Expand Down Expand Up @@ -282,33 +281,50 @@ public TaskDeployState deployTask(@NonNull TaskGroupImmutableInformation taskImm
taskImmutableInfo.getExecutionId()));
TaskGroup taskGroup = null;
try {
Set<ConnectorJarIdentifier> connectorJarIdentifiers =
List<Set<ConnectorJarIdentifier>> connectorJarIdentifiersList =
taskImmutableInfo.getConnectorJarIdentifiers();
Set<URL> jars = new HashSet<>();
ClassLoader classLoader;
if (!CollectionUtils.isEmpty(connectorJarIdentifiers)) {
// Prioritize obtaining the jar package file required for the current task execution
// from the local, if it does not exist locally, it will be downloaded from the
// master node.
jars =
serverConnectorPackageClient.getConnectorJarFromLocal(
connectorJarIdentifiers);
} else if (!CollectionUtils.isEmpty(taskImmutableInfo.getJars())) {
jars = taskImmutableInfo.getJars();
}
classLoader =
classLoaderService.getClassLoader(
taskImmutableInfo.getJobId(), Lists.newArrayList(jars));
if (jars.isEmpty()) {
taskGroup =
nodeEngine.getSerializationService().toObject(taskImmutableInfo.getGroup());
} else {
taskGroup =
CustomClassLoadedObject.deserializeWithCustomClassLoader(
nodeEngine.getSerializationService(),
classLoader,
taskImmutableInfo.getGroup());
List<Data> taskData = taskImmutableInfo.getTasksData();
ConcurrentHashMap<Long, ClassLoader> classLoaders = new ConcurrentHashMap<>();
List<Task> tasks = new ArrayList<>();
ConcurrentHashMap<Long, Collection<URL>> taskJars = new ConcurrentHashMap<>();
for (int i = 0; i < taskData.size(); i++) {
Set<URL> jars = new HashSet<>();
Set<ConnectorJarIdentifier> connectorJarIdentifiers =
connectorJarIdentifiersList.get(i);
if (!CollectionUtils.isEmpty(connectorJarIdentifiers)) {
// Prioritize obtaining the jar package file required for the current task
// execution
// from the local, if it does not exist locally, it will be downloaded from the
// master node.
jars =
serverConnectorPackageClient.getConnectorJarFromLocal(
connectorJarIdentifiers);
} else if (!CollectionUtils.isEmpty(taskImmutableInfo.getJars().get(i))) {
jars = taskImmutableInfo.getJars().get(i);
}
ClassLoader classLoader =
classLoaderService.getClassLoader(
taskImmutableInfo.getJobId(), Lists.newArrayList(jars));
Task task;
if (jars.isEmpty()) {
task = nodeEngine.getSerializationService().toObject(taskData.get(i));
} else {
task =
CustomClassLoadedObject.deserializeWithCustomClassLoader(
nodeEngine.getSerializationService(),
classLoader,
taskData.get(i));
}
tasks.add(task);
classLoaders.put(task.getTaskID(), classLoader);
taskJars.put(task.getTaskID(), jars);
}
taskGroup =
TaskGroupUtils.createTaskGroup(
taskImmutableInfo.getTaskGroupType(),
taskImmutableInfo.getTaskGroupLocation(),
taskImmutableInfo.getTaskGroupName(),
tasks);

logger.info(
String.format(
Expand All @@ -322,7 +338,7 @@ public TaskDeployState deployTask(@NonNull TaskGroupImmutableInformation taskImm
"TaskGroupLocation: %s already exists",
taskGroup.getTaskGroupLocation()));
}
deployLocalTask(taskGroup, classLoader, jars);
deployLocalTask(taskGroup, classLoaders, taskJars);
return TaskDeployState.success();
}
} catch (Throwable t) {
Expand All @@ -337,15 +353,10 @@ public TaskDeployState deployTask(@NonNull TaskGroupImmutableInformation taskImm
}
}

@Deprecated
public PassiveCompletableFuture<TaskExecutionState> deployLocalTask(
@NonNull TaskGroup taskGroup) {
return deployLocalTask(
taskGroup, Thread.currentThread().getContextClassLoader(), emptyList());
}

public PassiveCompletableFuture<TaskExecutionState> deployLocalTask(
@NonNull TaskGroup taskGroup, @NonNull ClassLoader classLoader, Collection<URL> jars) {
@NonNull TaskGroup taskGroup,
@NonNull ConcurrentHashMap<Long, ClassLoader> classLoaders,
ConcurrentHashMap<Long, Collection<URL>> jars) {
CompletableFuture<TaskExecutionState> resultFuture = new CompletableFuture<>();
try {
taskGroup.init();
Expand Down Expand Up @@ -389,7 +400,7 @@ public PassiveCompletableFuture<TaskExecutionState> deployLocalTask(
}));
executionContexts.put(
taskGroup.getTaskGroupLocation(),
new TaskGroupContext(taskGroup, classLoader, jars));
new TaskGroupContext(taskGroup, classLoaders, jars));
cancellationFutures.put(taskGroup.getTaskGroupLocation(), cancellationFuture);
submitThreadShareTask(executionTracker, byCooperation.get(true));
submitBlockingTask(executionTracker, byCooperation.get(false));
Expand Down Expand Up @@ -591,7 +602,7 @@ private void updateMetricsContextInImap() {
}
});
});
if (localMap.size() > 0) {
if (!localMap.isEmpty()) {
boolean lockedIMap = false;
try {
lockedIMap =
Expand Down Expand Up @@ -669,7 +680,8 @@ public void run() {
ClassLoader classLoader =
executionContexts
.get(taskGroupExecutionTracker.taskGroup.getTaskGroupLocation())
.getClassLoader();
.getClassLoaders()
.get(tracker.task.getTaskID());
ClassLoader oldClassLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(classLoader);
final Task t = tracker.task;
Expand Down Expand Up @@ -728,16 +740,16 @@ public final class CooperativeTaskWorker implements Runnable {
public AtomicReference<TaskTracker> exclusiveTaskTracker = new AtomicReference<>();
final TaskCallTimer timer;
private Thread myThread;
public LinkedBlockingDeque<TaskTracker> taskqueue;
public LinkedBlockingDeque<TaskTracker> taskQueue;
private Future<?> thisTaskFuture;
private BlockingQueue<Future<?>> futureBlockingQueue;

public CooperativeTaskWorker(
LinkedBlockingDeque<TaskTracker> taskqueue,
LinkedBlockingDeque<TaskTracker> taskQueue,
RunBusWorkSupplier runBusWorkSupplier,
BlockingQueue<Future<?>> futureBlockingQueue) {
logger.info(String.format("Created new BusWork : %s", this.hashCode()));
this.taskqueue = taskqueue;
this.taskQueue = taskQueue;
this.timer = new TaskCallTimer(50, keep, runBusWorkSupplier, this);
this.futureBlockingQueue = futureBlockingQueue;
}
Expand All @@ -752,7 +764,7 @@ public void run() {
TaskTracker taskTracker =
null != exclusiveTaskTracker.get()
? exclusiveTaskTracker.get()
: taskqueue.takeFirst();
: taskQueue.takeFirst();
TaskGroupExecutionTracker taskGroupExecutionTracker =
taskTracker.taskGroupExecutionTracker;
if (taskGroupExecutionTracker.executionCompletedExceptionally()) {
Expand All @@ -777,7 +789,8 @@ public void run() {
myThread.setContextClassLoader(
executionContexts
.get(taskGroupExecutionTracker.taskGroup.getTaskGroupLocation())
.getClassLoader());
.getClassLoaders()
.get(taskTracker.task.getTaskID()));
call = taskTracker.task.call();
synchronized (timer) {
timer.timerStop();
Expand Down Expand Up @@ -819,7 +832,7 @@ public void run() {
// Task is not completed. Put task to the end of the queue
// If the current work has an exclusive tracker, it will not be put back
if (null == exclusiveTaskTracker.get()) {
taskqueue.offer(taskTracker);
taskQueue.offer(taskTracker);
}
}
}
Expand All @@ -840,7 +853,7 @@ public RunBusWorkSupplier(
}

public boolean runNewBusWork(boolean checkTaskQueue) {
if (!checkTaskQueue || taskQueue.size() > 0) {
if (!checkTaskQueue || !taskQueue.isEmpty()) {
BlockingQueue<Future<?>> futureBlockingQueue = new LinkedBlockingQueue<>();
CooperativeTaskWorker cooperativeTaskWorker =
new CooperativeTaskWorker(taskQueue, this, futureBlockingQueue);
Expand All @@ -867,7 +880,7 @@ public final class TaskGroupExecutionTracker {

private final AtomicBoolean isCancel = new AtomicBoolean(false);

@Getter private Map<Long, Future<?>> currRunningTaskFuture = new ConcurrentHashMap<>();
private final Map<Long, Future<?>> currRunningTaskFuture = new ConcurrentHashMap<>();

TaskGroupExecutionTracker(
@NonNull CompletableFuture<Void> cancellationFuture,
Expand Down Expand Up @@ -972,8 +985,10 @@ void taskDone(Task task) {

private void recycleClassLoader(TaskGroupLocation taskGroupLocation) {
TaskGroupContext context = executionContexts.get(taskGroupLocation);
executionContexts.get(taskGroupLocation).setClassLoader(null);
classLoaderService.releaseClassLoader(taskGroupLocation.getJobId(), context.getJars());
executionContexts.get(taskGroupLocation).setClassLoaders(null);
for (Collection<URL> jars : context.getJars().values()) {
classLoaderService.releaseClassLoader(taskGroupLocation.getJobId(), jars);
}
}

boolean executionCompletedExceptionally() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ public void runInternal() throws Exception {
.getExecutionContext(taskLocation.getTaskGroupLocation());
Task task = groupContext.getTaskGroup().getTask(taskLocation.getTaskID());
ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(groupContext.getClassLoader());
Thread.currentThread()
.setContextClassLoader(
groupContext.getClassLoader(taskLocation.getTaskID()));

task.notifyCheckpointEnd(checkpointId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ public void runInternal() throws Exception {
.getExecutionContext(taskLocation.getTaskGroupLocation());
Task task = groupContext.getTaskGroup().getTask(taskLocation.getTaskID());
ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(groupContext.getClassLoader());
Thread.currentThread()
.setContextClassLoader(
groupContext.getClassLoader(taskLocation.getTaskID()));
if (successful) {
task.notifyCheckpointComplete(checkpointId);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ public void runInternal() throws Exception {
() -> {
Thread.currentThread()
.setContextClassLoader(
groupContext.getClassLoader());
groupContext.getClassLoader(
task.getTaskID()));
try {
log.debug(
"NotifyTaskRestoreOperation.restoreState "
Expand Down
Loading
Loading