Skip to content

Commit

Permalink
fix(wave): NVRTC remote execution
Browse files Browse the repository at this point in the history
Differential Revision: D67990177
  • Loading branch information
Yuhta authored and facebook-github-bot committed Jan 9, 2025
1 parent c965e7c commit 6a858b2
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 75 deletions.
2 changes: 2 additions & 0 deletions velox/experimental/wave/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ target_link_libraries(
CUDA::nvrtc
CUDA::cudart)

target_compile_definitions(velox_wave_common PRIVATE VELOX_OSS_BUILD=1)

if(${VELOX_BUILD_TESTING})
add_subdirectory(tests)
endif()
73 changes: 13 additions & 60 deletions velox/experimental/wave/common/Compile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
#include "velox/experimental/wave/jit/Headers.h"
#include "velox/external/jitify/jitify.hpp"

#ifndef VELOX_OSS_BUILD
#include "velox/facebook/NvrtcUtil.h"
#endif

namespace facebook::velox::wave {

void nvrtcCheck(nvrtcResult result) {
Expand Down Expand Up @@ -75,6 +79,14 @@ void addFlag(
data.push_back(std::move(str));
}

#ifdef VELOX_OSS_BUILD
void getDefaultNvrtcOptions(std::vector<std::string>& data) {
constexpr const char* kUsrLocalCuda = "/usr/local/cuda/include";
LOG(INFO) << "Using " << kUsrLocalCuda;
addFlag("-I", kUsrLocalCuda, strlen(kUsrLocalCuda), data);
}
#endif

// Gets compiler options from the environment and appends them to 'data'.
void getNvrtcOptions(std::vector<std::string>& data) {
const char* includes = getenv("WAVE_NVRTC_INCLUDE_PATH");
Expand All @@ -90,66 +102,7 @@ void getNvrtcOptions(std::vector<std::string>& data) {
includes = end + 1;
}
} else {
std::string currentPath = std::filesystem::current_path().c_str();
LOG(INFO) << "Looking for Cuda includes. cwd=" << currentPath
<< " Cuda=" << __CUDA_API_VER_MAJOR__ << "."
<< __CUDA_API_VER_MINOR__;
auto pathCStr = currentPath.c_str();
if (auto fbsource = strstr(pathCStr, "fbsource")) {
// fbcode has cuda includes in fbsource/third-party/cuda/...
try {
auto fbsourcePath =
std::string(pathCStr, fbsource - pathCStr + strlen("fbsource")) +
"/third-party/cuda";
LOG(INFO) << "Guessing fbsource path =" << fbsourcePath;
auto tempPath = fmt::format("/tmp/cuda.{}", getpid());
auto command = fmt::format(
"(cd {}; du |grep \"{}\\.{}.*x64-linux.*/cuda$\" |grep -v thrust) >{}",
fbsourcePath,
__CUDA_API_VER_MAJOR__,
__CUDA_API_VER_MINOR__,
tempPath);
LOG(INFO) << "Running " << command;
system(command.c_str());
std::ifstream result(tempPath);
std::string line;
if (!std::getline(result, line)) {
LOG(ERROR)
<< "Cuda includes matching build version not found in fbcode/third-party. Looking for latest cuda.";
command = fmt::format(
"(cd {}; du |grep \"{}\.*x64-linux.*/cuda$\" |grep -v thrust | sort -r) >{}",
fbsourcePath,
__CUDA_API_VER_MAJOR__,
tempPath);
LOG(INFO) << "Running " << command;
system(command.c_str());
std::ifstream result(tempPath);
if (!std::getline(result, line)) {
LOG(ERROR) << "Did not find any cuda with the same major version";
return;
}
}

LOG(INFO) << "Got cuda line: " << line;
// Now trim the size and the trailing /cuda from the line.
const char* start = strstr(line.c_str(), "./");
if (!start) {
LOG(ERROR) << "Line " << line << " does not have ./";
return;
}
auto path = fbsourcePath + "/" + (start + 2);
// We add the cwd + the found path minus the trailing /cuda.
addFlag("-I", path.c_str(), path.size() - 5, data);
} catch (const std::exception& e) {
LOG(ERROR) << "Failed to infer fbcode Cuda include path: " << e.what();
}
} else {
addFlag(
"-I",
"/usr/local/cuda/include",
strlen("/usr/local/cuda/include"),
data);
}
getDefaultNvrtcOptions(data);
}
const char* flags = getenv("WAVE_NVRTC_FLAGS");
if (flags && strlen(flags)) {
Expand Down
4 changes: 2 additions & 2 deletions velox/experimental/wave/common/StringView.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

#pragma once

#include <cassert>
#include <cstdint>
#include <assert.h>
#include <stdint.h>
#include "velox/experimental/wave/common/CompilerDefines.h"

namespace facebook::velox::wave {
Expand Down
2 changes: 1 addition & 1 deletion velox/experimental/wave/exec/Aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ Aggregation::Aggregation(
const std::shared_ptr<aggregation::AggregateFunctionRegistry>&
functionRegistry)
: WaveOperator(state, node.outputType(), node.id()),
arena_(&state.arena()),
arena_(state.arena()),
functionRegistry_(functionRegistry) {
VELOX_CHECK(node.step() == core::AggregationNode::Step::kSingle);
VELOX_CHECK(node.preGroupedKeys().empty());
Expand Down
8 changes: 2 additions & 6 deletions velox/experimental/wave/exec/ToWave.h
Original file line number Diff line number Diff line change
Expand Up @@ -745,18 +745,14 @@ class CompileState {
addExprSet(const exec::ExprSet& set, int32_t begin, int32_t end);
std::vector<std::vector<ProgramPtr>> makeLevels(int32_t startIndex);

GpuArena& arena() const {
return *arena_;
GpuArena* arena() const {
return arena_.get();
}

int numOperators() const {
return operators_.size();
}

GpuArena& arena() {
return *arena_;
}

std::stringstream& generated() {
return generated_;
}
Expand Down
6 changes: 3 additions & 3 deletions velox/experimental/wave/exec/Wave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1175,8 +1175,8 @@ void Program::callUpdateStatus(WaveStream& stream, AdvanceResult& advance) {
#define IN_OPERAND(member) \
physicalInst->member = operandIndex(abstractInst->member)

void Program::prepareForDevice(GpuArena& arena) {
arena_ = &arena;
void Program::prepareForDevice(GpuArena* arena) {
arena_ = arena;
if (kernel_) {
return;
}
Expand Down Expand Up @@ -1270,7 +1270,7 @@ void Program::prepareForDevice(GpuArena& arena) {
"OpCode {}", static_cast<int32_t>(instruction->opCode));
}
sortSlots();
deviceData_ = arena.allocate<char>(
deviceData_ = arena->allocate<char>(
codeSize + literalArea_.size() + sizeof(ThreadBlockProgram));
uintptr_t end = reinterpret_cast<uintptr_t>(
deviceData_->as<char>() + deviceData_->size());
Expand Down
2 changes: 1 addition & 1 deletion velox/experimental/wave/exec/Wave.h
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ class Program : public std::enable_shared_from_this<Program> {

// Initializes executableImage and relocation information and places
// the result on device.
void prepareForDevice(GpuArena& arena);
void prepareForDevice(GpuArena* arena);

std::unique_ptr<Executable> getExecutable(
int32_t maxRows,
Expand Down
4 changes: 2 additions & 2 deletions velox/experimental/wave/jit/Headers.h
Original file line number Diff line number Diff line change
Expand Up @@ -2178,8 +2178,8 @@ const char* velox_experimental_wave_common_StringView_h =
"\n"
"#pragma once\n"
"\n"
"#include <cassert>\n"
"#include <cstdint>\n"
"#include <assert.h>\n"
"#include <stdint.h>\n"
"#include \"velox/experimental/wave/common/CompilerDefines.h\"\n"
"\n"
"namespace facebook::velox::wave {\n"
Expand Down

0 comments on commit 6a858b2

Please sign in to comment.