diff --git a/src/accountsdb/db.zig b/src/accountsdb/db.zig index db4e6b424..1e09d3c6c 100644 --- a/src/accountsdb/db.zig +++ b/src/accountsdb/db.zig @@ -25,7 +25,7 @@ const SnapshotFields = @import("../accountsdb/snapshots.zig").SnapshotFields; const BankIncrementalSnapshotPersistence = @import("../accountsdb/snapshots.zig").BankIncrementalSnapshotPersistence; const Bank = @import("bank.zig").Bank; const readDirectory = @import("../utils/directory.zig").readDirectory; -const SnapshotPaths = @import("../accountsdb/snapshots.zig").SnapshotPaths; +const SnapshotFiles = @import("../accountsdb/snapshots.zig").SnapshotFiles; const AllSnapshotFields = @import("../accountsdb/snapshots.zig").AllSnapshotFields; const parallelUnpackZstdTarBall = @import("../accountsdb/snapshots.zig").parallelUnpackZstdTarBall; const Logger = @import("../trace/log.zig").Logger; @@ -962,8 +962,10 @@ fn loadTestAccountsDB(use_disk: bool) !struct { AccountsDB, AllSnapshotFields } true, ); - var snapshot_paths = try SnapshotPaths.find(allocator, dir_path); - var snapshots = try AllSnapshotFields.fromPaths(allocator, dir_path, snapshot_paths); + var snapshot_files = try SnapshotFiles.find(allocator, dir_path); + defer snapshot_files.deinit(allocator); + + var snapshots = try AllSnapshotFields.fromFiles(allocator, dir_path, snapshot_files); defer { allocator.free(snapshots.full_path); if (snapshots.incremental_path) |inc_path| { diff --git a/src/accountsdb/snapshots.zig b/src/accountsdb/snapshots.zig index 349e56381..21530ff1f 100644 --- a/src/accountsdb/snapshots.zig +++ b/src/accountsdb/snapshots.zig @@ -713,14 +713,21 @@ pub const StatusCache = struct { } }; -pub const FullSnapshotPath = struct { - path: []const u8, +/// information on a full snapshot including the filename, slot, and hash +pub const FullSnapshotFileInfo = struct { + filename: []const u8, slot: Slot, hash: []const u8, + const Self = @This(); + + pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { + 
allocator.free(self.filename); + } + /// matches with the regex: r"^snapshot-(?P[[:digit:]]+)-(?P[[:alnum:]]+)\.(?Ptar\.zst)$"; - pub fn fromPath(path: []const u8) !FullSnapshotPath { - var ext_parts = std.mem.splitSequence(u8, path, "."); + pub fn fromString(filename: []const u8) !Self { + var ext_parts = std.mem.splitSequence(u8, filename, "."); const stem = ext_parts.next() orelse return error.InvalidSnapshotPath; var extn = ext_parts.rest(); @@ -738,20 +745,27 @@ pub const FullSnapshotPath = struct { var hash = parts.next() orelse return error.InvalidSnapshotPath; - return FullSnapshotPath{ .path = path, .slot = slot, .hash = hash }; + return .{ .filename = filename, .slot = slot, .hash = hash }; } }; -pub const IncrementalSnapshotPath = struct { - path: []const u8, +/// information on an incremental snapshot including the filename, base slot (full snapshot), slot, and hash +pub const IncrementalSnapshotFileInfo = struct { + filename: []const u8, // this references the full snapshot slot base_slot: Slot, slot: Slot, hash: []const u8, + const Self = @This(); + + pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { + allocator.free(self.filename); + } + /// matches against regex: r"^incremental-snapshot-(?P[[:digit:]]+)-(?P[[:digit:]]+)-(?P[[:alnum:]]+)\.(?Ptar\.zst)$"; - pub fn fromPath(path: []const u8) !IncrementalSnapshotPath { - var ext_parts = std.mem.splitSequence(u8, path, "."); + pub fn fromString(filename: []const u8) !Self { + var ext_parts = std.mem.splitSequence(u8, filename, "."); const stem = ext_parts.next() orelse return error.InvalidSnapshotPath; var extn = ext_parts.rest(); @@ -776,8 +790,8 @@ pub const IncrementalSnapshotPath = struct { var hash = parts.next() orelse return error.InvalidSnapshotPath; - return IncrementalSnapshotPath{ - .path = path, + return .{ + .filename = filename, .slot = slot, .base_slot = base_slot, .hash = hash, @@ -785,12 +799,14 @@ pub const IncrementalSnapshotPath = struct { } }; -pub const SnapshotPaths = 
struct { - full_snapshot: FullSnapshotPath, - incremental_snapshot: ?IncrementalSnapshotPath, +pub const SnapshotFiles = struct { + full_snapshot: FullSnapshotFileInfo, + incremental_snapshot: ?IncrementalSnapshotFileInfo, + + const Self = @This(); /// finds existing snapshots (full and matching incremental) by looking for .tar.zstd files - pub fn find(allocator: std.mem.Allocator, snapshot_dir: []const u8) !SnapshotPaths { + pub fn find(allocator: std.mem.Allocator, snapshot_dir: []const u8) !Self { var snapshot_dir_iter = try std.fs.cwd().openIterableDir(snapshot_dir, .{}); defer snapshot_dir_iter.close(); @@ -802,34 +818,34 @@ pub const SnapshotPaths = struct { } // find the snapshots - var maybe_latest_full_snapshot: ?FullSnapshotPath = null; + var maybe_latest_full_snapshot: ?FullSnapshotFileInfo = null; var count: usize = 0; for (filenames.items) |filename| { - const snap_path = FullSnapshotPath.fromPath(filename) catch continue; - if (count == 0 or snap_path.slot > maybe_latest_full_snapshot.?.slot) { - maybe_latest_full_snapshot = snap_path; + const snapshot = FullSnapshotFileInfo.fromString(filename) catch continue; + if (count == 0 or snapshot.slot > maybe_latest_full_snapshot.?.slot) { + maybe_latest_full_snapshot = snapshot; } count += 1; } - var latest_full_snapshot = maybe_latest_full_snapshot orelse return error.NoFullSnapshotFound; + var latest_full_snapshot = maybe_latest_full_snapshot orelse return error.NoFullSnapshotFileInfoFound; // clone the name so we can deinit the full array - latest_full_snapshot.path = try snapshot_dir_iter.dir.realpathAlloc(allocator, latest_full_snapshot.path); + latest_full_snapshot.filename = try snapshot_dir_iter.dir.realpathAlloc(allocator, latest_full_snapshot.filename); count = 0; - var maybe_latest_incremental_snapshot: ?IncrementalSnapshotPath = null; + var maybe_latest_incremental_snapshot: ?IncrementalSnapshotFileInfo = null; for (filenames.items) |filename| { - const snap_path = 
IncrementalSnapshotPath.fromPath(filename) catch continue; + const snapshot = IncrementalSnapshotFileInfo.fromString(filename) catch continue; // need to match the base slot - if (snap_path.base_slot == latest_full_snapshot.slot and (count == 0 or + if (snapshot.base_slot == latest_full_snapshot.slot and (count == 0 or // this unwrap is safe because count > 0 - snap_path.slot > maybe_latest_incremental_snapshot.?.slot)) + snapshot.slot > maybe_latest_incremental_snapshot.?.slot)) { - maybe_latest_incremental_snapshot = snap_path; + maybe_latest_incremental_snapshot = snapshot; } count += 1; } if (maybe_latest_incremental_snapshot) |*latest_incremental_snapshot| { - latest_incremental_snapshot.path = try snapshot_dir_iter.dir.realpathAlloc(allocator, latest_incremental_snapshot.path); + latest_incremental_snapshot.filename = try snapshot_dir_iter.dir.realpathAlloc(allocator, latest_incremental_snapshot.filename); } return .{ @@ -838,22 +854,28 @@ pub const SnapshotPaths = struct { }; } - pub fn deinit(self: *SnapshotPaths, allocator: std.mem.Allocator) void { - allocator.free(self.full_snapshot.path); - if (self.incremental_snapshot) |incremental_snapshot| { - allocator.free(incremental_snapshot.path); + pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { + self.full_snapshot.deinit(allocator); + if (self.incremental_snapshot) |*incremental_snapshot| { + incremental_snapshot.deinit(allocator); } } }; +/// contains all fields from a snapshot (full and incremental) pub const AllSnapshotFields = struct { full: SnapshotFields, incremental: ?SnapshotFields, - paths: SnapshotPaths, was_collapsed: bool = false, // used for deinit() - pub fn fromPaths(allocator: std.mem.Allocator, snapshot_dir: []const u8, paths: SnapshotPaths) !struct { - all_fields: AllSnapshotFields, + const Self = @This(); + + pub fn fromFiles( + allocator: std.mem.Allocator, + snapshot_dir_str: []const u8, + files: SnapshotFiles, + ) !struct { + all_fields: Self, full_path: []const u8, 
incremental_path: ?[]const u8, } { @@ -861,7 +883,7 @@ pub const AllSnapshotFields = struct { const full_metadata_path = try std.fmt.allocPrint( allocator, "{s}/{s}/{d}/{d}", - .{ snapshot_dir, "snapshots", paths.full_snapshot.slot, paths.full_snapshot.slot }, + .{ snapshot_dir_str, "snapshots", files.full_snapshot.slot, files.full_snapshot.slot }, ); var full_fields = try SnapshotFields.readFromFilePath( @@ -871,11 +893,11 @@ pub const AllSnapshotFields = struct { var incremental_fields: ?SnapshotFields = null; var incremental_metadata_path: ?[]const u8 = null; - if (paths.incremental_snapshot) |incremental_snapshot_path| { + if (files.incremental_snapshot) |incremental_snapshot_path| { incremental_metadata_path = try std.fmt.allocPrint( allocator, "{s}/{s}/{d}/{d}", - .{ snapshot_dir, "snapshots", incremental_snapshot_path.slot, incremental_snapshot_path.slot }, + .{ snapshot_dir_str, "snapshots", incremental_snapshot_path.slot, incremental_snapshot_path.slot }, ); incremental_fields = try SnapshotFields.readFromFilePath( @@ -884,10 +906,9 @@ pub const AllSnapshotFields = struct { ); } - const all_fields = .{ + const all_fields: Self = .{ .full = full_fields, .incremental = incremental_fields, - .paths = paths, }; return .{ @@ -903,7 +924,7 @@ pub const AllSnapshotFields = struct { /// this will 1) modify the incremental snapshot account map /// and 2) the returned snapshot heap fields will still point to the incremental snapshot /// (so be sure not to deinit it while still using the returned snapshot) - pub fn collapse(self: *AllSnapshotFields) !SnapshotFields { + pub fn collapse(self: *Self) !SnapshotFields { // nothing to collapse if (self.incremental == null) return self.full; @@ -940,7 +961,7 @@ pub const AllSnapshotFields = struct { return snapshot; } - pub fn deinit(self: *AllSnapshotFields, allocator: std.mem.Allocator) void { + pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { if (!self.was_collapsed) { self.full.deinit(allocator); if 
(self.incremental) |inc| { @@ -955,7 +976,6 @@ pub const AllSnapshotFields = struct { bincode.free(allocator, inc.accounts_db_fields.rooted_slot_hashes); } } - self.paths.deinit(allocator); } }; @@ -993,21 +1013,21 @@ pub fn parallelUnpackZstdTarBall( test "core.accounts_db.snapshots: test full snapshot path parsing" { const full_snapshot_path = "snapshot-269-EAHHZCVccCdAoCXH8RWxvv9edcwjY2boqni9MJuh3TCn.tar.zst"; - const snapshot_info = try FullSnapshotPath.fromPath(full_snapshot_path); + const snapshot_info = try FullSnapshotFileInfo.fromString(full_snapshot_path); try std.testing.expect(snapshot_info.slot == 269); try std.testing.expect(std.mem.eql(u8, snapshot_info.hash, "EAHHZCVccCdAoCXH8RWxvv9edcwjY2boqni9MJuh3TCn")); - try std.testing.expect(std.mem.eql(u8, snapshot_info.path, full_snapshot_path)); + try std.testing.expect(std.mem.eql(u8, snapshot_info.filename, full_snapshot_path)); } test "core.accounts_db.snapshots: test incremental snapshot path parsing" { const path = "incremental-snapshot-269-307-4JLFzdaaqkSrmHs55bBDhZrQjHYZvqU1vCcQ5mP22pdB.tar.zst"; - const snapshot_info = try IncrementalSnapshotPath.fromPath(path); + const snapshot_info = try IncrementalSnapshotFileInfo.fromString(path); try std.testing.expect(snapshot_info.base_slot == 269); try std.testing.expect(snapshot_info.slot == 307); try std.testing.expect(std.mem.eql(u8, snapshot_info.hash, "4JLFzdaaqkSrmHs55bBDhZrQjHYZvqU1vCcQ5mP22pdB")); - try std.testing.expect(std.mem.eql(u8, snapshot_info.path, path)); + try std.testing.expect(std.mem.eql(u8, snapshot_info.filename, path)); } test "core.accounts_db.snapshotss: parse status cache" { diff --git a/src/bincode/bincode.zig b/src/bincode/bincode.zig index 4038fa6b5..a5716f3f3 100644 --- a/src/bincode/bincode.zig +++ b/src/bincode/bincode.zig @@ -838,7 +838,7 @@ test "bincode: custom field serialization" { var buf: [1000]u8 = undefined; var out = try writeToSlice(&buf, foo, Params{}); - std.debug.print("{any}", .{out}); + // 
std.debug.print("{any}", .{out}); try std.testing.expect(out[out.len - 1] != 20); // skip worked var size = try getSerializedSize(std.testing.allocator, foo, Params{}); @@ -846,7 +846,7 @@ test "bincode: custom field serialization" { var r = try readFromSlice(std.testing.allocator, Foo, out, Params{}); defer free(std.testing.allocator, r); - std.debug.print("{any}", .{r}); + // std.debug.print("{any}", .{r}); try std.testing.expect(r.accounts.len == foo.accounts.len); try std.testing.expect(r.txs.len == foo.txs.len); diff --git a/src/cmd/cmd.zig b/src/cmd/cmd.zig index 6f37bd56c..bb4bbd395 100644 --- a/src/cmd/cmd.zig +++ b/src/cmd/cmd.zig @@ -1,25 +1,39 @@ const std = @import("std"); -const cli = @import("zig-cli"); const base58 = @import("base58-zig"); +const cli = @import("zig-cli"); const dns = @import("zigdig"); -const enumFromName = @import("../utils/types.zig").enumFromName; -const getOrInitIdentity = @import("./helpers.zig").getOrInitIdentity; -const ContactInfo = @import("../gossip/data.zig").ContactInfo; -const SOCKET_TAG_GOSSIP = @import("../gossip/data.zig").SOCKET_TAG_GOSSIP; -const Logger = @import("../trace/log.zig").Logger; -const Level = @import("../trace/level.zig").Level; -const io = std.io; -const Pubkey = @import("../core/pubkey.zig").Pubkey; -const SocketAddr = @import("../net/net.zig").SocketAddr; -const echo = @import("../net/echo.zig"); -const GossipService = @import("../gossip/service.zig").GossipService; -const servePrometheus = @import("../prometheus/http.zig").servePrometheus; -const globalRegistry = @import("../prometheus/registry.zig").globalRegistry; -const Registry = @import("../prometheus/registry.zig").Registry; -const getWallclockMs = @import("../gossip/data.zig").getWallclockMs; -const IpAddr = @import("../lib.zig").net.IpAddr; - -const SnapshotPaths = @import("../accountsdb/snapshots.zig").SnapshotPaths; +const network = @import("zig-network"); +const sig = @import("../lib.zig"); +const helpers = @import("helpers.zig"); + 
+const Atomic = std.atomic.Atomic; +const KeyPair = std.crypto.sign.Ed25519.KeyPair; +const Random = std.rand.Random; +const Socket = network.Socket; + +const ContactInfo = sig.gossip.ContactInfo; +const GossipService = sig.gossip.GossipService; +const IpAddr = sig.net.IpAddr; +const Level = sig.trace.Level; +const Logger = sig.trace.Logger; +const Pubkey = sig.core.Pubkey; +const Registry = sig.prometheus.Registry; +const RepairService = sig.tvu.RepairService; +const RepairPeerProvider = sig.tvu.RepairPeerProvider; +const RepairRequester = sig.tvu.RepairRequester; +const ShredReceiver = sig.tvu.ShredReceiver; +const SocketAddr = sig.net.SocketAddr; + +const enumFromName = sig.utils.enumFromName; +const getOrInitIdentity = helpers.getOrInitIdentity; +const globalRegistry = sig.prometheus.globalRegistry; +const getWallclockMs = sig.gossip.getWallclockMs; +const requestIpEcho = sig.net.requestIpEcho; +const servePrometheus = sig.prometheus.servePrometheus; + +const socket_tag = sig.gossip.socket_tag; + +const SnapshotFiles = @import("../accountsdb/snapshots.zig").SnapshotFiles; const parallelUnpackZstdTarBall = @import("../accountsdb/snapshots.zig").parallelUnpackZstdTarBall; const AllSnapshotFields = @import("../accountsdb/snapshots.zig").AllSnapshotFields; const AccountsDB = @import("../accountsdb/db.zig").AccountsDB; @@ -65,6 +79,22 @@ var gossip_port_option = cli.Option{ .value_name = "Gossip Port", }; +var repair_port_option = cli.Option{ + .long_name = "repair-port", + .help = "The port to run tvu repair listener - default: 8002", + .value = cli.OptionValue{ .int = 8002 }, + .required = false, + .value_name = "Repair Port", +}; + +var test_repair_option = cli.Option{ + .long_name = "test-repair-for-slot", + .help = "Set a slot here to repeatedly send repair requests for shreds from this slot. This is only intended for use during short-lived tests of the repair service. 
Do not set this during normal usage.", + .value = cli.OptionValue{ .int = null }, + .required = false, + .value_name = "slot number", +}; + var gossip_entrypoints_option = cli.Option{ .long_name = "entrypoint", .help = "gossip address of the entrypoint validators", @@ -179,10 +209,28 @@ var app = &cli.App{ , .action = gossip, .options = &.{ + &gossip_host.option, + &gossip_port_option, + &gossip_entrypoints_option, + &gossip_spy_node_option, + &gossip_dump_option, + }, + }, + &cli.Command{ + .name = "validator", + .help = "Run validator", + .description = + \\Start a full Solana validator client. + , + .action = validator, + .options = &.{ + &gossip_host.option, &gossip_port_option, &gossip_entrypoints_option, &gossip_spy_node_option, &gossip_dump_option, + &repair_port_option, + &test_repair_option, }, }, &cli.Command{ @@ -203,7 +251,7 @@ var app = &cli.App{ }, }; -// prints (and creates if DNE) pubkey in ~/.sig/identity.key +/// entrypoint to print (and create if DNE) pubkey in ~/.sig/identity.key fn identity(_: []const []const u8) !void { var logger = Logger.init(gpa_allocator, try enumFromName(Level, log_level_option.value.string.?)); defer logger.deinit(); @@ -215,21 +263,189 @@ fn identity(_: []const []const u8) !void { try std.io.getStdErr().writer().print("Identity: {s}\n", .{pubkey[0..size]}); } -// gossip entrypoint +/// entrypoint to run only gossip fn gossip(_: []const []const u8) !void { - var logger = Logger.init(gpa_allocator, try enumFromName(Level, log_level_option.value.string.?)); + var logger = try spawnLogger(); defer logger.deinit(); - logger.spawn(); + const metrics_thread = try spawnMetrics(logger); + defer metrics_thread.detach(); - const metrics_thread = try spawnMetrics(gpa_allocator, logger); + var exit = std.atomic.Atomic(bool).init(false); + const my_keypair = try getOrInitIdentity(gpa_allocator, logger); + const entrypoints = try getEntrypoints(logger); + defer entrypoints.deinit(); + const my_data = try getMyDataFromIpEcho(logger, 
entrypoints.items); - var my_keypair = try getOrInitIdentity(gpa_allocator, logger); + var gossip_service = try initGossip( + logger, + my_keypair, + &exit, + entrypoints, + my_data.shred_version, + my_data.ip, + &.{}, + ); + defer gossip_service.deinit(); + + var handle = try spawnGossip(&gossip_service); + handle.join(); +} + +/// entrypoint to run a full solana validator +fn validator(_: []const []const u8) !void { + var logger = try spawnLogger(); + defer logger.deinit(); + const metrics_thread = try spawnMetrics(logger); + defer metrics_thread.detach(); + var rand = std.rand.DefaultPrng.init(@bitCast(std.time.timestamp())); + var exit = std.atomic.Atomic(bool).init(false); + const my_keypair = try getOrInitIdentity(gpa_allocator, logger); + const entrypoints = try getEntrypoints(logger); + defer entrypoints.deinit(); + const ip_echo_data = try getMyDataFromIpEcho(logger, entrypoints.items); + + const repair_port: u16 = @intCast(repair_port_option.value.int.?); + + var gossip_service = try initGossip( + logger, + my_keypair, + &exit, + entrypoints, + ip_echo_data.shred_version, // TODO atomic owned at top level? or owned by gossip is good? 
+ ip_echo_data.ip, + &.{.{ .tag = socket_tag.REPAIR, .port = repair_port }}, + ); + defer gossip_service.deinit(); + var gossip_handle = try spawnGossip(&gossip_service); + + var repair_socket = try Socket.create(network.AddressFamily.ipv4, network.Protocol.udp); + try repair_socket.bindToPort(repair_port); + try repair_socket.setReadTimeout(sig.net.SOCKET_TIMEOUT); + + var repair_svc = try initRepair(logger, &my_keypair, &exit, rand.random(), &gossip_service, &repair_socket); + defer repair_svc.deinit(); + var repair_handle = try std.Thread.spawn(.{}, RepairService.run, .{&repair_svc}); + + var shred_receiver = ShredReceiver{ + .allocator = gpa_allocator, + .keypair = &my_keypair, + .exit = &exit, + .logger = logger, + .socket = &repair_socket, + }; + var shred_receive_handle = try std.Thread.spawn(.{}, ShredReceiver.run, .{&shred_receiver}); + + gossip_handle.join(); + repair_handle.join(); + shred_receive_handle.join(); +} + +/// Initialize an instance of GossipService and configure with CLI arguments +fn initGossip( + logger: Logger, + my_keypair: KeyPair, + exit: *Atomic(bool), + entrypoints: std.ArrayList(SocketAddr), + shred_version: u16, + gossip_host_ip: IpAddr, + sockets: []const struct { tag: u8, port: u16 }, +) !GossipService { var gossip_port: u16 = @intCast(gossip_port_option.value.int.?); + logger.infof("gossip host: {any}", .{gossip_host_ip}); logger.infof("gossip port: {d}", .{gossip_port}); + // setup contact info + var my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key); + var contact_info = ContactInfo.init(gpa_allocator, my_pubkey, getWallclockMs(), 0); + try contact_info.setSocket(socket_tag.GOSSIP, SocketAddr.init(gossip_host_ip, gossip_port)); + for (sockets) |s| try contact_info.setSocket(s.tag, SocketAddr.init(gossip_host_ip, s.port)); + contact_info.shred_version = shred_version; + + return try GossipService.init( + gpa_allocator, + contact_info, + my_keypair, + entrypoints, + exit, + logger, + ); +} + +fn initRepair( + logger: 
Logger, + my_keypair: *const KeyPair, + exit: *Atomic(bool), + random: Random, + gossip_service: *GossipService, + socket: *Socket, +) !RepairService { + var peer_provider = try RepairPeerProvider.init( + gpa_allocator, + random, + &gossip_service.gossip_table_rw, + Pubkey.fromPublicKey(&my_keypair.public_key), + &gossip_service.my_shred_version, + ); + return RepairService{ + .allocator = gpa_allocator, + .requester = RepairRequester{ + .allocator = gpa_allocator, + .rng = random, + .udp_send_socket = socket, + .keypair = my_keypair, + .logger = logger, + }, + .peer_provider = peer_provider, + .logger = logger, + .exit = exit, + .slot_to_request = if (test_repair_option.value.int) |n| @intCast(n) else null, + }; +} + +/// Spawn a thread to run gossip and configure with CLI arguments +fn spawnGossip(gossip_service: *GossipService) std.Thread.SpawnError!std.Thread { + const spy_node = gossip_spy_node_option.value.bool; + return try std.Thread.spawn( + .{}, + GossipService.run, + .{ gossip_service, spy_node, gossip_dump_option.value.bool }, + ); +} + +/// determine our shred version and ip. in the solana-labs client, the shred version +/// comes from the snapshot, and ip echo is only used to validate it. 
+fn getMyDataFromIpEcho( + logger: Logger, + entrypoints: []SocketAddr, +) !struct { shred_version: u16, ip: IpAddr } { + var my_ip_from_entrypoint: ?IpAddr = null; + const my_shred_version = loop: for (entrypoints) |entrypoint| { + if (requestIpEcho(gpa_allocator, entrypoint.toAddress(), .{})) |response| { + if (my_ip_from_entrypoint == null) my_ip_from_entrypoint = response.address; + if (response.shred_version) |shred_version| { + var addr_str = entrypoint.toString(); + logger.infof( + "shred version: {} - from entrypoint ip echo: {s}", + .{ shred_version.value, addr_str[0][0..addr_str[1]] }, + ); + break shred_version.value; + } + } else |_| {} + } else { + logger.warn("could not get a shred version from an entrypoint"); + break :loop 0; + }; + const my_ip = try gossip_host.get() orelse my_ip_from_entrypoint orelse IpAddr.newIpv4(127, 0, 0, 1); + logger.infof("my ip: {}", .{my_ip}); + return .{ + .shred_version = my_shred_version, + .ip = my_ip, + }; +} + +fn getEntrypoints(logger: Logger) !std.ArrayList(SocketAddr) { var entrypoints = std.ArrayList(SocketAddr).init(gpa_allocator); - defer entrypoints.deinit(); if (gossip_entrypoints_option.value.string_list) |entrypoints_strs| { for (entrypoints_strs) |entrypoint| { var socket_addr = SocketAddr.parse(entrypoint) catch brk: { @@ -281,65 +497,22 @@ fn gossip(_: []const []const u8) !void { } logger.infof("entrypoints: {s}", .{entrypoint_string[0..stream.pos]}); - // determine our shred version and ip. in the solana-labs client, the shred version - // comes from the snapshot, and ip echo is only used to validate it. 
- var my_ip_from_entrypoint: ?IpAddr = null; - const my_shred_version = loop: for (entrypoints.items) |entrypoint| { - if (echo.requestIpEcho(gpa_allocator, entrypoint.toAddress(), .{})) |response| { - if (my_ip_from_entrypoint == null) my_ip_from_entrypoint = response.address; - if (response.shred_version) |shred_version| { - var addr_str = entrypoint.toString(); - logger.infof( - "shred version: {} - from entrypoint ip echo: {s}", - .{ shred_version.value, addr_str[0][0..addr_str[1]] }, - ); - break shred_version.value; - } - } else |_| {} - } else { - logger.warn("could not get a shred version from an entrypoint"); - break :loop 0; - }; - const my_ip = try gossip_host.get() orelse my_ip_from_entrypoint orelse IpAddr.newIpv4(127, 0, 0, 1); - logger.infof("my ip: {}", .{my_ip}); - - // setup contact info - var my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key); - var contact_info = ContactInfo.init(gpa_allocator, my_pubkey, getWallclockMs(), 0); - contact_info.shred_version = my_shred_version; - var gossip_address = SocketAddr.init(my_ip, gossip_port); - try contact_info.setSocket(SOCKET_TAG_GOSSIP, gossip_address); - - var exit = std.atomic.Atomic(bool).init(false); - var gossip_service = try GossipService.init( - gpa_allocator, - contact_info, - my_keypair, - entrypoints, - &exit, - logger, - ); - defer gossip_service.deinit(); - - const spy_node = gossip_spy_node_option.value.bool; - var handle = try std.Thread.spawn( - .{}, - GossipService.run, - .{ &gossip_service, spy_node, gossip_dump_option.value.bool }, - ); - - handle.join(); - metrics_thread.detach(); + return entrypoints; } /// Initializes the global registry. Returns error if registry was already initialized. /// Spawns a thread to serve the metrics over http on the CLI configured port. -/// Uses same allocator for both registry and http adapter. 
-fn spawnMetrics(allocator: std.mem.Allocator, logger: Logger) !std.Thread { +fn spawnMetrics(logger: Logger) !std.Thread { var metrics_port: u16 = @intCast(metrics_port_option.value.int.?); logger.infof("metrics port: {d}", .{metrics_port}); const registry = globalRegistry(); - return try std.Thread.spawn(.{}, servePrometheus, .{ allocator, registry, metrics_port }); + return try std.Thread.spawn(.{}, servePrometheus, .{ gpa_allocator, registry, metrics_port }); +} + +fn spawnLogger() !Logger { + var logger = Logger.init(gpa_allocator, try enumFromName(Level, log_level_option.value.string.?)); + logger.spawn(); + return logger; } fn accountsDb(_: []const []const u8) !void { @@ -352,7 +525,7 @@ fn accountsDb(_: []const []const u8) !void { // arg parsing const disk_index_path: ?[]const u8 = disk_index_path_option.value.string; const force_unpack_snapshot = force_unpack_snapshot_option.value.bool; - const snapshot_dir = snapshot_dir_option.value.string.?; + const snapshot_dir_str = snapshot_dir_option.value.string.?; const n_cpus = @as(u32, @truncate(try std.Thread.getCpuCount())); var n_threads_snapshot_load: u32 = @intCast(n_threads_snapshot_load_option.value.int.?); @@ -369,7 +542,7 @@ fn accountsDb(_: []const []const u8) !void { const genesis_path = try std.fmt.allocPrint( allocator, "{s}/genesis.bin", - .{snapshot_dir}, + .{snapshot_dir_str}, ); defer allocator.free(genesis_path); @@ -382,7 +555,7 @@ fn accountsDb(_: []const []const u8) !void { const accounts_path = try std.fmt.allocPrint( allocator, "{s}/accounts/", - .{snapshot_dir}, + .{snapshot_dir_str}, ); defer allocator.free(accounts_path); @@ -392,8 +565,9 @@ fn accountsDb(_: []const []const u8) !void { }; const should_unpack_snapshot = !accounts_path_exists or force_unpack_snapshot; - var snapshot_paths = try SnapshotPaths.find(allocator, snapshot_dir); - if (snapshot_paths.incremental_snapshot == null) { + var snapshot_files = try SnapshotFiles.find(allocator, snapshot_dir_str); + defer 
snapshot_files.deinit(allocator); + if (snapshot_files.incremental_snapshot == null) { logger.infof("no incremental snapshot found", .{}); } @@ -403,31 +577,30 @@ fn accountsDb(_: []const []const u8) !void { if (should_unpack_snapshot) { logger.infof("unpacking snapshots...", .{}); // if accounts/ doesnt exist then we unpack the found snapshots - var snapshot_dir_iter = try std.fs.cwd().openIterableDir(snapshot_dir, .{}); - defer snapshot_dir_iter.close(); + var snapshot_dir = try std.fs.cwd().openDir(snapshot_dir_str, .{}); + defer snapshot_dir.close(); // TODO: delete old accounts/ dir if it exists timer.reset(); - std.debug.print("unpacking {s}...", .{snapshot_paths.full_snapshot.path}); - logger.infof("unpacking {s}...", .{snapshot_paths.full_snapshot.path}); + logger.infof("unpacking {s}...", .{snapshot_files.full_snapshot.filename}); try parallelUnpackZstdTarBall( allocator, - snapshot_paths.full_snapshot.path, - snapshot_dir_iter.dir, + snapshot_files.full_snapshot.filename, + snapshot_dir, n_threads_snapshot_unpack, true, ); logger.infof("unpacked snapshot in {s}", .{std.fmt.fmtDuration(timer.read())}); // TODO: can probs do this in parallel with full snapshot - if (snapshot_paths.incremental_snapshot) |incremental_snapshot| { + if (snapshot_files.incremental_snapshot) |incremental_snapshot| { timer.reset(); - logger.infof("unpacking {s}...", .{incremental_snapshot.path}); + logger.infof("unpacking {s}...", .{incremental_snapshot.filename}); try parallelUnpackZstdTarBall( allocator, - incremental_snapshot.path, - snapshot_dir_iter.dir, + incremental_snapshot.filename, + snapshot_dir, n_threads_snapshot_unpack, false, ); @@ -439,7 +612,7 @@ fn accountsDb(_: []const []const u8) !void { timer.reset(); logger.infof("reading snapshot metadata...", .{}); - var snapshots = try AllSnapshotFields.fromPaths(allocator, snapshot_dir, snapshot_paths); + var snapshots = try AllSnapshotFields.fromFiles(allocator, snapshot_dir_str, snapshot_files); defer { 
snapshots.all_fields.deinit(allocator); allocator.free(snapshots.full_path); @@ -498,7 +671,7 @@ fn accountsDb(_: []const []const u8) !void { const status_cache_path = try std.fmt.allocPrint( allocator, "{s}/{s}", - .{ snapshot_dir, "snapshots/status_cache" }, + .{ snapshot_dir_str, "snapshots/status_cache" }, ); defer allocator.free(status_cache_path); diff --git a/src/common/lru.zig b/src/common/lru.zig index 0e12f55b3..3ed328291 100644 --- a/src/common/lru.zig +++ b/src/common/lru.zig @@ -10,8 +10,27 @@ pub const Kind = enum { non_locking, }; -// TODO: allow for passing custom hash context to use in std.ArrayHashMap for performance. -pub fn LruCache(comptime kind: Kind, comptime K: type, comptime V: type) type { +pub fn LruCache( + comptime kind: Kind, + comptime K: type, + comptime V: type, +) type { + return LruCacheCustom(kind, K, V, void, struct { + fn noop(_: *V, _: void) void {} + }.noop); +} + +/// LruCache that allows you to specify a custom deinit function +/// to call on a node's data when the node is removed. +/// +/// TODO: allow for passing custom hash context to use in std.ArrayHashMap for performance. +pub fn LruCacheCustom( + comptime kind: Kind, + comptime K: type, + comptime V: type, + comptime DeinitContext: type, + comptime deinitFn: fn (*V, DeinitContext) void, +) type { return struct { mux: if (kind == .locking) Mutex else void, allocator: Allocator, @@ -19,6 +38,7 @@ pub fn LruCache(comptime kind: Kind, comptime K: type, comptime V: type) type { dbl_link_list: TailQueue(LruEntry), max_items: usize, len: usize = 0, + deinit_context: DeinitContext, const Self = @This(); @@ -47,10 +67,21 @@ pub fn LruCache(comptime kind: Kind, comptime K: type, comptime V: type) type { fn deinitNode(self: *Self, node: *Node) void { self.len -= 1; + deinitFn(&node.data.value, self.deinit_context); self.allocator.destroy(node); } - pub fn init(allocator: Allocator, max_items: usize) error{OutOfMemory}!Self { + /// Use if DeinitContext is void. 
+ pub fn init(allocator: Allocator, max_items: usize) error{OutOfMemory}!LruCache(kind, K, V) { + return LruCache(kind, K, V).initWithContext(allocator, max_items, void{}); + } + + /// Use if DeinitContext is not void. + pub fn initWithContext( + allocator: Allocator, + max_items: usize, + deinit_context: DeinitContext, + ) error{OutOfMemory}!Self { var hashmap = if (K == []const u8) std.StringArrayHashMap(*Node).init(allocator) else std.AutoArrayHashMap(K, *Node).init(allocator); var self = Self{ .allocator = allocator, @@ -58,6 +89,7 @@ pub fn LruCache(comptime kind: Kind, comptime K: type, comptime V: type) type { .dbl_link_list = TailQueue(LruEntry){}, .max_items = max_items, .mux = if (kind == .locking) Mutex{} else undefined, + .deinit_context = deinit_context, }; // pre allocate enough capacity for max items since we will use diff --git a/src/core/hard_forks.zig b/src/core/hard_forks.zig index e8da22e2b..b68d5991d 100644 --- a/src/core/hard_forks.zig +++ b/src/core/hard_forks.zig @@ -62,7 +62,11 @@ pub const HardForks = struct { }; test "core.hard_forks: test hardforks" { + const Logger = @import("../trace/log.zig").Logger; const testing_alloc = std.testing.allocator; + var logger = Logger.init(testing_alloc, Logger.TEST_DEFAULT_LEVEL); + defer logger.deinit(); + logger.spawn(); var hard_forks = HardForks.default(testing_alloc); defer hard_forks.deinit(); @@ -78,13 +82,13 @@ test "core.hard_forks: test hardforks" { var hash_data_one = hard_forks.get_hash_data(9, 0); try expect(hash_data_one == null); - std.debug.print("hash_data_one: {any}\n", .{hash_data_one}); + logger.debugf("hash_data_one: {any}", .{hash_data_one}); var hash_data_two = hard_forks.get_hash_data(10, 0); try expect(hash_data_two != null); try expect(std.mem.eql(u8, &hash_data_two.?, &[8]u8{ 1, 0, 0, 0, 0, 0, 0, 0 })); - std.debug.print("hard_forks_two: {any}\n", .{hash_data_two}); + logger.debugf("hard_forks_two: {any}", .{hash_data_two}); try expect(eql(u8, &hard_forks.get_hash_data(19, 
0).?, &[8]u8{ 1, 0, 0, 0, 0, 0, 0, 0 })); try expect(eql(u8, &hard_forks.get_hash_data(20, 0).?, &[8]u8{ 2, 0, 0, 0, 0, 0, 0, 0 })); diff --git a/src/core/shred.zig b/src/core/shred.zig index 54086476f..fbd8c00ef 100644 --- a/src/core/shred.zig +++ b/src/core/shred.zig @@ -3,6 +3,8 @@ const std = @import("std"); const HardForks = @import("hard_forks.zig").HardForks; const Allocator = std.mem.Allocator; +pub const Nonce = u32; + pub const ShredVersion = struct { value: u16, @@ -36,22 +38,26 @@ pub const ShredVersion = struct { }; test "core.shred: test ShredVersion" { + const Logger = @import("../trace/log.zig").Logger; var hash = Hash{ .data = [_]u8{ 180, 194, 54, 239, 216, 26, 164, 170, 3, 72, 104, 87, 32, 189, 12, 254, 9, 103, 99, 155, 117, 158, 241, 0, 95, 128, 64, 174, 42, 158, 205, 26 } }; const version = ShredVersion.versionFromHash(&hash); try std.testing.expect(version == 44810); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); + defer logger.deinit(); + logger.spawn(); const testing_alloc = std.testing.allocator; var shred_version_one = try ShredVersion.computeShredVersion(testing_alloc, Hash.default(), null); try std.testing.expect(shred_version_one == 1); - std.debug.print("shred_version_one: {}\n", .{shred_version_one}); + logger.debugf("shred_version_one: {}", .{shred_version_one}); var hard_forks = HardForks.default(testing_alloc); defer _ = hard_forks.deinit(); var shred_version_two = try ShredVersion.computeShredVersion(testing_alloc, Hash.default(), hard_forks); try std.testing.expect(shred_version_two == 1); - std.debug.print("shred_version_two: {}\n", .{shred_version_two}); + logger.debugf("shred_version_two: {}", .{shred_version_two}); try hard_forks.register(1); var shred_version_three = try ShredVersion.computeShredVersion( @@ -60,10 +66,10 @@ test "core.shred: test ShredVersion" { hard_forks, ); try std.testing.expect(shred_version_three == 55551); - std.debug.print("shred_version_three: {}\n", 
.{shred_version_three}); + logger.debugf("shred_version_three: {}", .{shred_version_three}); try hard_forks.register(1); var shred_version_four = try ShredVersion.computeShredVersion(testing_alloc, Hash.default(), hard_forks); try std.testing.expect(shred_version_four == 46353); - std.debug.print("shred_version_three: {}\n", .{shred_version_four}); + logger.debugf("shred_version_three: {}", .{shred_version_four}); } diff --git a/src/core/signature.zig b/src/core/signature.zig index 0320d51cd..4a88d5ba4 100644 --- a/src/core/signature.zig +++ b/src/core/signature.zig @@ -1,6 +1,8 @@ const std = @import("std"); const Pubkey = @import("pubkey.zig").Pubkey; const Ed25519 = std.crypto.sign.Ed25519; +const Verifier = std.crypto.sign.Ed25519.Verifier; +const e = std.crypto.errors; pub const SIGNATURE_LENGTH: usize = 64; @@ -15,12 +17,20 @@ pub const Signature = struct { }; } - pub fn verify(self: Self, pubkey: Pubkey, msg: []u8) bool { + pub fn verify(self: Self, pubkey: Pubkey, msg: []const u8) bool { const sig = Ed25519.Signature.fromBytes(self.data); sig.verify(msg, Ed25519.PublicKey.fromBytes(pubkey.data) catch unreachable) catch return false; return true; } + pub fn verifier( + self: Self, + pubkey: Pubkey, + ) (e.NonCanonicalError || e.EncodingError || e.IdentityElementError)!Verifier { + const sig = Ed25519.Signature.fromBytes(self.data); + return sig.verifier(Ed25519.PublicKey.fromBytes(pubkey.data) catch unreachable); + } + pub fn eql(self: *const Self, other: *const Self) bool { return std.mem.eql(u8, self.data[0..], other.data[0..]); } diff --git a/src/gossip/active_set.zig b/src/gossip/active_set.zig index 0a172270a..c1372bb2f 100644 --- a/src/gossip/active_set.zig +++ b/src/gossip/active_set.zig @@ -7,7 +7,7 @@ const _gossip_data = @import("../gossip/data.zig"); const SignedGossipData = _gossip_data.SignedGossipData; const getWallclockMs = _gossip_data.getWallclockMs; const ContactInfo = _gossip_data.ContactInfo; -const SOCKET_TAG_GOSSIP = 
_gossip_data.SOCKET_TAG_GOSSIP; +const socket_tag = _gossip_data.socket_tag; const LegacyContactInfo = _gossip_data.LegacyContactInfo; const Pubkey = @import("../core/pubkey.zig").Pubkey; @@ -107,7 +107,7 @@ pub const ActiveSet = struct { while (iter.next()) |entry| { // lookup peer contact info const peer_info = table.getContactInfo(entry.key_ptr.*) orelse continue; - const peer_gossip_addr = peer_info.getSocket(SOCKET_TAG_GOSSIP) orelse continue; + const peer_gossip_addr = peer_info.getSocket(socket_tag.GOSSIP) orelse continue; peer_gossip_addr.sanitize() catch continue; diff --git a/src/gossip/data.zig b/src/gossip/data.zig index 1f8b2c1c5..6b3f86258 100644 --- a/src/gossip/data.zig +++ b/src/gossip/data.zig @@ -10,7 +10,7 @@ const ArrayList = std.ArrayList; const KeyPair = std.crypto.sign.Ed25519.KeyPair; const Pubkey = @import("../core/pubkey.zig").Pubkey; const sanitizeWallclock = @import("./message.zig").sanitizeWallclock; -const PACKET_DATA_SIZE = @import("./packet.zig").PACKET_DATA_SIZE; +const PACKET_DATA_SIZE = @import("../net/packet.zig").PACKET_DATA_SIZE; const network = @import("zig-network"); const var_int = @import("../utils/varint.zig"); @@ -414,7 +414,7 @@ pub const GossipData = union(enum(u32)) { pub fn gossipAddr(self: *const @This()) ?SocketAddr { return switch (self.*) { .LegacyContactInfo => |*v| if (v.gossip.isUnspecified()) null else v.gossip, - .ContactInfo => |*v| v.getSocket(SOCKET_TAG_GOSSIP), + .ContactInfo => |*v| v.getSocket(socket_tag.GOSSIP), else => null, }; } @@ -501,32 +501,32 @@ pub const LegacyContactInfo = struct { /// call ContactInfo.deinit to free pub fn toContactInfo(self: *const LegacyContactInfo, allocator: std.mem.Allocator) !ContactInfo { var ci = ContactInfo.init(allocator, self.id, self.wallclock, self.shred_version); - try ci.setSocket(SOCKET_TAG_GOSSIP, self.gossip); - try ci.setSocket(SOCKET_TAG_TVU, self.tvu); - try ci.setSocket(SOCKET_TAG_TVU_FORWARDS, self.tvu_forwards); - try ci.setSocket(SOCKET_TAG_REPAIR, 
self.repair); - try ci.setSocket(SOCKET_TAG_TPU, self.tpu); - try ci.setSocket(SOCKET_TAG_TPU_FORWARDS, self.tpu_forwards); - try ci.setSocket(SOCKET_TAG_TPU_VOTE, self.tpu_vote); - try ci.setSocket(SOCKET_TAG_RPC, self.rpc); - try ci.setSocket(SOCKET_TAG_RPC_PUBSUB, self.rpc_pubsub); - try ci.setSocket(SOCKET_TAG_SERVE_REPAIR, self.serve_repair); + try ci.setSocket(socket_tag.GOSSIP, self.gossip); + try ci.setSocket(socket_tag.TVU, self.tvu); + try ci.setSocket(socket_tag.TVU_FORWARDS, self.tvu_forwards); + try ci.setSocket(socket_tag.REPAIR, self.repair); + try ci.setSocket(socket_tag.TPU, self.tpu); + try ci.setSocket(socket_tag.TPU_FORWARDS, self.tpu_forwards); + try ci.setSocket(socket_tag.TPU_VOTE, self.tpu_vote); + try ci.setSocket(socket_tag.RPC, self.rpc); + try ci.setSocket(socket_tag.RPC_PUBSUB, self.rpc_pubsub); + try ci.setSocket(socket_tag.SERVE_REPAIR, self.serve_repair); return ci; } pub fn fromContactInfo(ci: *const ContactInfo) LegacyContactInfo { return .{ .id = ci.pubkey, - .gossip = ci.getSocket(SOCKET_TAG_GOSSIP) orelse SocketAddr.UNSPECIFIED, - .tvu = ci.getSocket(SOCKET_TAG_TVU) orelse SocketAddr.UNSPECIFIED, - .tvu_forwards = ci.getSocket(SOCKET_TAG_TVU_FORWARDS) orelse SocketAddr.UNSPECIFIED, - .repair = ci.getSocket(SOCKET_TAG_REPAIR) orelse SocketAddr.UNSPECIFIED, - .tpu = ci.getSocket(SOCKET_TAG_TPU) orelse SocketAddr.UNSPECIFIED, - .tpu_forwards = ci.getSocket(SOCKET_TAG_TPU_FORWARDS) orelse SocketAddr.UNSPECIFIED, - .tpu_vote = ci.getSocket(SOCKET_TAG_TPU_VOTE) orelse SocketAddr.UNSPECIFIED, - .rpc = ci.getSocket(SOCKET_TAG_RPC) orelse SocketAddr.UNSPECIFIED, - .rpc_pubsub = ci.getSocket(SOCKET_TAG_RPC_PUBSUB) orelse SocketAddr.UNSPECIFIED, - .serve_repair = ci.getSocket(SOCKET_TAG_SERVE_REPAIR) orelse SocketAddr.UNSPECIFIED, + .gossip = ci.getSocket(socket_tag.GOSSIP) orelse SocketAddr.UNSPECIFIED, + .tvu = ci.getSocket(socket_tag.TVU) orelse SocketAddr.UNSPECIFIED, + .tvu_forwards = ci.getSocket(socket_tag.TVU_FORWARDS) orelse 
SocketAddr.UNSPECIFIED, + .repair = ci.getSocket(socket_tag.REPAIR) orelse SocketAddr.UNSPECIFIED, + .tpu = ci.getSocket(socket_tag.TPU) orelse SocketAddr.UNSPECIFIED, + .tpu_forwards = ci.getSocket(socket_tag.TPU_FORWARDS) orelse SocketAddr.UNSPECIFIED, + .tpu_vote = ci.getSocket(socket_tag.TPU_VOTE) orelse SocketAddr.UNSPECIFIED, + .rpc = ci.getSocket(socket_tag.RPC) orelse SocketAddr.UNSPECIFIED, + .rpc_pubsub = ci.getSocket(socket_tag.RPC_PUBSUB) orelse SocketAddr.UNSPECIFIED, + .serve_repair = ci.getSocket(socket_tag.SERVE_REPAIR) orelse SocketAddr.UNSPECIFIED, .wallclock = ci.wallclock, .shred_version = ci.shred_version, }; @@ -938,20 +938,22 @@ pub const SnapshotHashes = struct { } }; -pub const SOCKET_TAG_GOSSIP: u8 = 0; -pub const SOCKET_TAG_REPAIR: u8 = 1; -pub const SOCKET_TAG_RPC: u8 = 2; -pub const SOCKET_TAG_RPC_PUBSUB: u8 = 3; -pub const SOCKET_TAG_SERVE_REPAIR: u8 = 4; -pub const SOCKET_TAG_TPU: u8 = 5; -pub const SOCKET_TAG_TPU_FORWARDS: u8 = 6; -pub const SOCKET_TAG_TPU_FORWARDS_QUIC: u8 = 7; -pub const SOCKET_TAG_TPU_QUIC: u8 = 8; -pub const SOCKET_TAG_TPU_VOTE: u8 = 9; -pub const SOCKET_TAG_TVU: u8 = 10; -pub const SOCKET_TAG_TVU_FORWARDS: u8 = 11; -pub const SOCKET_TAG_TVU_QUIC: u8 = 12; -pub const SOCKET_CACHE_SIZE: usize = SOCKET_TAG_TVU_QUIC + 1; +pub const socket_tag = struct { + pub const GOSSIP: u8 = 0; + pub const REPAIR: u8 = 1; + pub const RPC: u8 = 2; + pub const RPC_PUBSUB: u8 = 3; + pub const SERVE_REPAIR: u8 = 4; + pub const TPU: u8 = 5; + pub const TPU_FORWARDS: u8 = 6; + pub const TPU_FORWARDS_QUIC: u8 = 7; + pub const TPU_QUIC: u8 = 8; + pub const TPU_VOTE: u8 = 9; + pub const TVU: u8 = 10; + pub const TVU_FORWARDS: u8 = 11; + pub const TVU_QUIC: u8 = 12; +}; +pub const SOCKET_CACHE_SIZE: usize = socket_tag.TVU_QUIC + 1; pub const ContactInfo = struct { pubkey: Pubkey, @@ -984,7 +986,7 @@ pub const ContactInfo = struct { pub fn initSpy(allocator: std.mem.Allocator, id: Pubkey, gossip_socket_addr: SocketAddr, shred_version: u16) 
!Self { var contact_info = Self.init(allocator, id, @intCast(std.time.microTimestamp()), shred_version); - try contact_info.setSocket(SOCKET_TAG_GOSSIP, gossip_socket_addr); + try contact_info.setSocket(socket_tag.GOSSIP, gossip_socket_addr); return contact_info; } @@ -1254,12 +1256,12 @@ test "gossip.data: set & get socket on contact info" { var ci = ContactInfo.init(testing.allocator, Pubkey.random(rng), @as(u64, @intCast(std.time.microTimestamp())), 0); defer ci.deinit(); - try ci.setSocket(SOCKET_TAG_RPC, SocketAddr.initIpv4(.{ 127, 0, 0, 1 }, 8899)); + try ci.setSocket(socket_tag.RPC, SocketAddr.initIpv4(.{ 127, 0, 0, 1 }, 8899)); - var set_socket = ci.getSocket(SOCKET_TAG_RPC); + var set_socket = ci.getSocket(socket_tag.RPC); try testing.expect(set_socket.?.eql(&SocketAddr.initIpv4(.{ 127, 0, 0, 1 }, 8899))); try testing.expect(ci.addrs.items[0].eql(&IpAddr.newIpv4(127, 0, 0, 1))); - try testing.expect(ci.sockets.items[0].eql(&SocketEntry.init(SOCKET_TAG_RPC, 0, 8899))); + try testing.expect(ci.sockets.items[0].eql(&SocketEntry.init(socket_tag.RPC, 0, 8899))); } test "gossip.data: contact info bincode serialize matches rust bincode" { diff --git a/src/gossip/fuzz.zig b/src/gossip/fuzz.zig index 405a877fb..7aee9029f 100644 --- a/src/gossip/fuzz.zig +++ b/src/gossip/fuzz.zig @@ -15,7 +15,7 @@ const _gossip_data = @import("data.zig"); const LegacyContactInfo = _gossip_data.LegacyContactInfo; const SignedGossipData = _gossip_data.SignedGossipData; const ContactInfo = _gossip_data.ContactInfo; -const SOCKET_TAG_GOSSIP = _gossip_data.SOCKET_TAG_GOSSIP; +const socket_tag = _gossip_data.socket_tag; const AtomicBool = std.atomic.Atomic(bool); const SocketAddr = @import("../net/net.zig").SocketAddr; @@ -26,8 +26,8 @@ const getWallclockMs = @import("data.zig").getWallclockMs; const Bloom = @import("../bloom/bloom.zig").Bloom; const network = @import("zig-network"); const EndPoint = network.EndPoint; -const Packet = @import("packet.zig").Packet; -const PACKET_DATA_SIZE = 
@import("packet.zig").PACKET_DATA_SIZE; +const Packet = @import("../net/packet.zig").Packet; +const PACKET_DATA_SIZE = @import("../net/packet.zig").PACKET_DATA_SIZE; const NonBlockingChannel = @import("../sync/channel.zig").NonBlockingChannel; const Thread = std.Thread; @@ -278,7 +278,7 @@ pub fn main() !void { var fuzz_pubkey = Pubkey.fromPublicKey(&fuzz_keypair.public_key); var fuzz_contact_info = ContactInfo.init(allocator, fuzz_pubkey, 0, 19); - try fuzz_contact_info.setSocket(SOCKET_TAG_GOSSIP, fuzz_address); + try fuzz_contact_info.setSocket(socket_tag.GOSSIP, fuzz_address); var fuzz_exit = AtomicBool.init(false); var gossip_service_fuzzer = try GossipService.init( diff --git a/src/gossip/message.zig b/src/gossip/message.zig index b0d1fd4c3..78fb6ceb3 100644 --- a/src/gossip/message.zig +++ b/src/gossip/message.zig @@ -11,7 +11,7 @@ const LegacyContactInfo = _gossip_data.LegacyContactInfo; const getWallclockMs = _gossip_data.getWallclockMs; const GossipPullFilter = @import("pull_request.zig").GossipPullFilter; -const PACKET_DATA_SIZE = @import("packet.zig").PACKET_DATA_SIZE; +const PACKET_DATA_SIZE = @import("../net/packet.zig").PACKET_DATA_SIZE; const DefaultPrng = std.rand.DefaultPrng; const KeyPair = std.crypto.sign.Ed25519.KeyPair; @@ -231,7 +231,7 @@ test "gossip.message: push message serialization is predictable" { msg_with_value, bincode.Params{}, ); - std.debug.print("value_size, empty_size, msg_value_size: {d} {d} {d}\n", .{ value_size, empty_size, msg_value_size }); + // std.debug.print("value_size, empty_size, msg_value_size: {d} {d} {d}\n", .{ value_size, empty_size, msg_value_size }); try std.testing.expectEqual(value_size + empty_size, msg_value_size); } diff --git a/src/gossip/ping_pong.zig b/src/gossip/ping_pong.zig index a9923bc55..4b2a8f1ef 100644 --- a/src/gossip/ping_pong.zig +++ b/src/gossip/ping_pong.zig @@ -6,7 +6,7 @@ const _gossip_data = @import("data.zig"); const SignedGossipData = _gossip_data.SignedGossipData; const GossipData = 
_gossip_data.GossipData; const ContactInfo = _gossip_data.ContactInfo; -const SOCKET_TAG_GOSSIP = _gossip_data.SOCKET_TAG_GOSSIP; +const socket_tag = _gossip_data.socket_tag; const getWallclockMs = _gossip_data.getWallclockMs; const DefaultPrng = std.rand.DefaultPrng; @@ -53,7 +53,7 @@ pub const Ping = struct { }; } - pub fn verify(self: *Self) !void { + pub fn verify(self: *const Self) !void { if (!self.signature.verify(self.from, &self.token)) { return error.InvalidSignature; } @@ -79,7 +79,7 @@ pub const Pong = struct { }; } - pub fn verify(self: *Self) !void { + pub fn verify(self: *const Self) error{InvalidSignature}!void { if (!self.signature.verify(self.from, &self.hash.data)) { return error.InvalidSignature; } @@ -89,6 +89,12 @@ pub const Pong = struct { const ping = try Ping.random(rng, keypair); return try Pong.init(&ping, keypair); } + + pub fn eql(self: *const @This(), other: *const @This()) bool { + return std.mem.eql(u8, &self.from.data, &other.from.data) and + std.mem.eql(u8, &self.hash.data, &other.hash.data) and + std.mem.eql(u8, &self.signature.data, &other.signature.data); + } }; /// `PubkeyAndSocketAddr` is a 2 element tuple: `.{ Pubkey, SocketAddr }` @@ -218,7 +224,7 @@ pub const PingCache = struct { var pings = std.ArrayList(PingAndSocketAddr).init(allocator); for (peers, 0..) 
|*peer, i| { - if (peer.getSocket(SOCKET_TAG_GOSSIP)) |gossip_addr| { + if (peer.getSocket(socket_tag.GOSSIP)) |gossip_addr| { var result = self.check(now, PubkeyAndSocketAddr{ .pubkey = peer.pubkey, .socket_addr = gossip_addr }, &our_keypair); if (result.passes_ping_check) { try valid_peers.append(i); diff --git a/src/gossip/service.zig b/src/gossip/service.zig index cd1981048..9b6deef39 100644 --- a/src/gossip/service.zig +++ b/src/gossip/service.zig @@ -1,8 +1,8 @@ const std = @import("std"); const network = @import("zig-network"); const EndPoint = network.EndPoint; -const Packet = @import("packet.zig").Packet; -const PACKET_DATA_SIZE = @import("packet.zig").PACKET_DATA_SIZE; +const Packet = @import("../net/packet.zig").Packet; +const PACKET_DATA_SIZE = @import("../net/packet.zig").PACKET_DATA_SIZE; const ThreadPool = @import("../sync/thread_pool.zig").ThreadPool; const Task = ThreadPool.Task; const Batch = ThreadPool.Batch; @@ -24,7 +24,7 @@ const bincode = @import("../bincode/bincode.zig"); const gossip = @import("../gossip/data.zig"); const LegacyContactInfo = gossip.LegacyContactInfo; const ContactInfo = @import("data.zig").ContactInfo; -const SOCKET_TAG_GOSSIP = @import("data.zig").SOCKET_TAG_GOSSIP; +const socket_tag = @import("data.zig").socket_tag; const SignedGossipData = gossip.SignedGossipData; const KeyPair = std.crypto.sign.Ed25519.KeyPair; const Pubkey = @import("../core/pubkey.zig").Pubkey; @@ -278,7 +278,7 @@ pub const GossipService = struct { var active_set = ActiveSet.init(allocator); // bind the socket - const gossip_address = my_contact_info.getSocket(SOCKET_TAG_GOSSIP) orelse return error.GossipAddrUnspecified; + const gossip_address = my_contact_info.getSocket(socket_tag.GOSSIP) orelse return error.GossipAddrUnspecified; var gossip_socket = UdpSocket.create(.ipv4, .udp) catch return error.SocketCreateFailed; gossip_socket.bindToPort(gossip_address.port()) catch return error.SocketBindFailed; 
gossip_socket.setReadTimeout(socket_utils.SOCKET_TIMEOUT) catch return error.SocketSetTimeoutFailed; // 1 second @@ -930,7 +930,7 @@ pub const GossipService = struct { // update wallclock and sign self.my_contact_info.wallclock = getWallclockMs(); var my_contact_info_value = try gossip.SignedGossipData.initSigned(gossip.GossipData{ - .ContactInfo = self.my_contact_info, + .ContactInfo = try self.my_contact_info.clone(), }, &self.my_keypair); var my_legacy_contact_info_value = try gossip.SignedGossipData.initSigned(gossip.GossipData{ .LegacyContactInfo = LegacyContactInfo.fromContactInfo(&self.my_contact_info), @@ -1234,7 +1234,7 @@ pub const GossipService = struct { const peer_index = rng.random().intRangeAtMost(usize, 0, num_peers - 1); const peer_contact_info_index = valid_gossip_peer_indexs.items[peer_index]; const peer_contact_info = peers[peer_contact_info_index]; - if (peer_contact_info.getSocket(SOCKET_TAG_GOSSIP)) |gossip_addr| { + if (peer_contact_info.getSocket(socket_tag.GOSSIP)) |gossip_addr| { const message = GossipMessage{ .PullRequest = .{ filter_i, my_contact_info_value } }; var packet = &packet_batch.items[packet_index]; @@ -1609,7 +1609,7 @@ pub const GossipService = struct { return error.CantFindContactInfo; }; }; - const from_gossip_addr = from_contact_info.getSocket(SOCKET_TAG_GOSSIP) orelse return error.InvalidGossipAddress; + const from_gossip_addr = from_contact_info.getSocket(socket_tag.GOSSIP) orelse return error.InvalidGossipAddress; gossip.sanitizeSocket(&from_gossip_addr) catch return error.InvalidGossipAddress; const from_gossip_endpoint = from_gossip_addr.toEndpoint(); @@ -1730,7 +1730,7 @@ pub const GossipService = struct { // unable to find contact info continue; }; - const from_gossip_addr = from_contact_info.getSocket(SOCKET_TAG_GOSSIP) orelse continue; + const from_gossip_addr = from_contact_info.getSocket(socket_tag.GOSSIP) orelse continue; from_gossip_addr.sanitize() catch { // invalid gossip socket continue; @@ -1965,7 +1965,7 
@@ pub const GossipService = struct { var node_index: usize = 0; for (contact_infos) |contact_info| { - const peer_gossip_addr = contact_info.getSocket(SOCKET_TAG_GOSSIP); + const peer_gossip_addr = contact_info.getSocket(socket_tag.GOSSIP); // filter self if (contact_info.pubkey.equals(&self.my_pubkey)) { @@ -2118,7 +2118,7 @@ test "gossip.gossip_service: build messages startup and shutdown" { var my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key); const contact_info = try localhostTestContactInfo(my_pubkey); - var logger = Logger.init(std.testing.allocator, .debug); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); @@ -2171,7 +2171,7 @@ test "gossip.gossip_service: tests handling prune messages" { var my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key); const contact_info = try localhostTestContactInfo(my_pubkey); - var logger = Logger.init(std.testing.allocator, .debug); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); @@ -2245,7 +2245,7 @@ test "gossip.gossip_service: tests handling pull responses" { var my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key); const contact_info = try localhostTestContactInfo(my_pubkey); - var logger = Logger.init(std.testing.allocator, .debug); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); @@ -2304,7 +2304,7 @@ test "gossip.gossip_service: tests handle pull request" { var my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key); const contact_info = try localhostTestContactInfo(my_pubkey); - var logger = Logger.init(std.testing.allocator, .debug); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); @@ -2374,7 +2374,7 @@ test "gossip.gossip_service: tests handle pull request" { defer pull_requests.deinit(); try 
pull_requests.append(GossipService.PullRequestMessage{ .filter = filter, - .from_endpoint = (contact_info.getSocket(SOCKET_TAG_GOSSIP) orelse unreachable).toEndpoint(), + .from_endpoint = (contact_info.getSocket(socket_tag.GOSSIP) orelse unreachable).toEndpoint(), .value = gossip_value, }); @@ -2395,7 +2395,7 @@ test "gossip.gossip_service: test build prune messages and handle push messages" var my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key); const contact_info = try localhostTestContactInfo(my_pubkey); - var logger = Logger.init(std.testing.allocator, .debug); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); @@ -2482,7 +2482,7 @@ test "gossip.gossip_service: test build pull requests" { var my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key); const contact_info = try localhostTestContactInfo(my_pubkey); - var logger = Logger.init(std.testing.allocator, .debug); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); @@ -2524,7 +2524,7 @@ test "gossip.gossip_service: test build push messages" { var my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key); const contact_info = try localhostTestContactInfo(my_pubkey); - var logger = Logger.init(std.testing.allocator, .debug); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); @@ -2597,7 +2597,7 @@ test "gossip.gossip_service: test packet verification" { var id = Pubkey.fromPublicKey(&keypair.public_key); const contact_info = try localhostTestContactInfo(id); - var logger = Logger.init(std.testing.allocator, .debug); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); @@ -2727,7 +2727,7 @@ test "gossip.gossip_service: process contact info push packet" { var my_pubkey = Pubkey.fromPublicKey(&my_keypair.public_key); const contact_info = try 
localhostTestContactInfo(my_pubkey); - var logger = Logger.init(std.testing.allocator, .debug); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); @@ -2814,9 +2814,9 @@ test "gossip.gossip_service: init, exit, and deinit" { var my_keypair = try KeyPair.create(null); var rng = std.rand.DefaultPrng.init(getWallclockMs()); var contact_info = try LegacyContactInfo.random(rng.random()).toContactInfo(std.testing.allocator); - try contact_info.setSocket(SOCKET_TAG_GOSSIP, gossip_address); + try contact_info.setSocket(socket_tag.GOSSIP, gossip_address); var exit = AtomicBool.init(false); - var logger = Logger.init(std.testing.allocator, .debug); + var logger = Logger.init(std.testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); @@ -2872,7 +2872,7 @@ pub const BenchmarkGossipServiceGeneral = struct { var pubkey = Pubkey.fromPublicKey(&keypair.public_key); var contact_info = ContactInfo.init(allocator, pubkey, 0, 19); - try contact_info.setSocket(SOCKET_TAG_GOSSIP, address); + try contact_info.setSocket(socket_tag.GOSSIP, address); // var logger = Logger.init(allocator, .debug); // defer logger.deinit(); @@ -2979,6 +2979,6 @@ pub const BenchmarkGossipServiceGeneral = struct { fn localhostTestContactInfo(id: Pubkey) !ContactInfo { var contact_info = try LegacyContactInfo.default(id).toContactInfo(std.testing.allocator); - try contact_info.setSocket(SOCKET_TAG_GOSSIP, SocketAddr.initIpv4(.{ 127, 0, 0, 1 }, 0)); + try contact_info.setSocket(socket_tag.GOSSIP, SocketAddr.initIpv4(.{ 127, 0, 0, 1 }, 0)); return contact_info; } diff --git a/src/gossip/table.zig b/src/gossip/table.zig index b38836989..5d3befa7b 100644 --- a/src/gossip/table.zig +++ b/src/gossip/table.zig @@ -12,7 +12,7 @@ const GossipVersionedData = _gossip_data.GossipVersionedData; const GossipKey = _gossip_data.GossipKey; const LegacyContactInfo = _gossip_data.LegacyContactInfo; const ContactInfo = 
_gossip_data.ContactInfo; -const SOCKET_TAG_GOSSIP = _gossip_data.SOCKET_TAG_GOSSIP; +const socket_tag = _gossip_data.socket_tag; const getWallclockMs = _gossip_data.getWallclockMs; const Vote = _gossip_data.Vote; @@ -26,7 +26,7 @@ const KeyPair = std.crypto.sign.Ed25519.KeyPair; const RwLock = std.Thread.RwLock; const SocketAddr = @import("../net/net.zig").SocketAddr; -const PACKET_DATA_SIZE = @import("./packet.zig").PACKET_DATA_SIZE; +const PACKET_DATA_SIZE = @import("../net/packet.zig").PACKET_DATA_SIZE; pub const UNIQUE_PUBKEY_CAPACITY: usize = 8192; pub const MAX_TABLE_SIZE: usize = 1_000_000; // TODO: better value for this @@ -127,7 +127,6 @@ pub const GossipTable = struct { } pub fn deinit(self: *Self) void { - self.store.deinit(); self.contact_infos.deinit(); self.shred_versions.deinit(); self.votes.deinit(); @@ -148,6 +147,12 @@ pub const GossipTable = struct { entry.value_ptr.deinit(); } self.converted_contact_infos.deinit(); + + var store_iter = self.store.iterator(); + while (store_iter.next()) |entry| { + bincode.free(self.allocator, entry.value_ptr.value.data); + } + self.store.deinit(); } pub fn insert(self: *Self, value: SignedGossipData, now: u64) !void { @@ -484,32 +489,65 @@ pub const GossipTable = struct { ); } + /// Returns a slice of contact infos that are no older than minimum_insertion_timestamp. + /// You must provide a buffer to fill with the contact infos. If you want all contact + /// infos, the buffer should be at least `self.contact_infos.count()` in size. 
pub fn getContactInfos( self: *const Self, buf: []ContactInfo, minimum_insertion_timestamp: u64, ) []ContactInfo { - const store_values = self.store.values(); - const contact_indexs = self.contact_infos.iterator().keys; - - var tgt_idx: usize = 0; - for (0..self.contact_infos.count()) |src_idx| { - if (tgt_idx >= buf.len) break; - const index = contact_indexs[src_idx]; - const entry = store_values[index]; - if (entry.timestamp_on_insertion >= minimum_insertion_timestamp) { - const contact_info = switch (entry.value.data) { - .LegacyContactInfo => |lci| self.converted_contact_infos.get(lci.id) orelse unreachable, - .ContactInfo => |ci| ci, - else => unreachable, - }; - buf[tgt_idx] = contact_info; - tgt_idx += 1; - } + var infos = self.contactInfoIterator(minimum_insertion_timestamp); + var i: usize = 0; + while (infos.next()) |info| { + if (i >= buf.len) break; + buf[i] = info.*; + i += 1; } - return buf[0..tgt_idx]; + return buf[0..i]; + } + + /// Similar to getContactInfos, but returns an iterator instead + /// of a slice. This allows you to avoid an allocation and avoid + /// copying every value. 
+ pub fn contactInfoIterator( + self: *const Self, + minimum_insertion_timestamp: u64, + ) ContactInfoIterator { + return .{ + .values = self.store.values(), + .converted_contact_infos = &self.converted_contact_infos, + .indices = self.contact_infos.iterator().keys, + .count = self.contact_infos.count(), + .minimum_insertion_timestamp = minimum_insertion_timestamp, + }; } + pub const ContactInfoIterator = struct { + values: []const GossipVersionedData, + converted_contact_infos: *const AutoArrayHashMap(Pubkey, ContactInfo), + indices: [*]usize, + count: usize, + minimum_insertion_timestamp: u64, + index_cursor: usize = 0, + + pub fn next(self: *@This()) ?*const ContactInfo { + while (self.index_cursor < self.count) { + const index = self.indices[self.index_cursor]; + self.index_cursor += 1; + const value = &self.values[index]; + if (value.timestamp_on_insertion >= self.minimum_insertion_timestamp) { + return switch (value.value.data) { + .LegacyContactInfo => |*lci| self.converted_contact_infos.getPtr(lci.id) orelse unreachable, + .ContactInfo => |*ci| ci, + else => unreachable, + }; + } + } + return null; + } + }; + // ** shard getter fcns ** pub fn getBitmaskMatches( self: *const Self, @@ -822,7 +860,7 @@ pub const GossipTable = struct { for (contact_indexs) |index| { const entry: GossipVersionedData = self.store.values()[index]; switch (entry.value.data) { - .ContactInfo => |ci| if (ci.getSocket(SOCKET_TAG_GOSSIP)) |addr| { + .ContactInfo => |ci| if (ci.getSocket(socket_tag.GOSSIP)) |addr| { if (addr.eql(&gossip_addr)) return try ci.clone(); }, .LegacyContactInfo => |lci| if (lci.gossip.eql(&gossip_addr)) { diff --git a/src/lib.zig b/src/lib.zig index c17dbb0ae..2c4382bf8 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -31,7 +31,6 @@ pub const gossip = struct { pub usingnamespace @import("gossip/data.zig"); pub usingnamespace @import("gossip/table.zig"); pub usingnamespace @import("gossip/service.zig"); - pub usingnamespace @import("gossip/packet.zig"); pub 
usingnamespace @import("gossip/message.zig"); pub usingnamespace @import("gossip/pull_request.zig"); pub usingnamespace @import("gossip/pull_response.zig"); @@ -66,6 +65,7 @@ pub const utils = struct { }; pub const trace = struct { + pub usingnamespace @import("trace/level.zig"); pub usingnamespace @import("trace/log.zig"); pub usingnamespace @import("trace/entry.zig"); }; @@ -86,6 +86,7 @@ pub const cmd = struct { pub const net = struct { pub usingnamespace @import("net/net.zig"); pub usingnamespace @import("net/echo.zig"); + pub usingnamespace @import("net/packet.zig"); pub usingnamespace @import("net/socket_utils.zig"); }; @@ -98,3 +99,9 @@ pub const prometheus = struct { pub usingnamespace @import("prometheus/metric.zig"); pub usingnamespace @import("prometheus/registry.zig"); }; + +pub const tvu = struct { + pub usingnamespace @import("tvu/repair_message.zig"); + pub usingnamespace @import("tvu/repair_service.zig"); + pub usingnamespace @import("tvu/shred_receiver.zig"); +}; diff --git a/src/net/echo.zig b/src/net/echo.zig index 132037a2f..705772b04 100644 --- a/src/net/echo.zig +++ b/src/net/echo.zig @@ -317,7 +317,7 @@ test "net.echo: Server works" { const port: u16 = 34333; // initialize logger - var logger = Logger.init(testing.allocator, .info); + var logger = Logger.init(testing.allocator, Logger.TEST_DEFAULT_LEVEL); defer logger.deinit(); logger.spawn(); diff --git a/src/gossip/packet.zig b/src/net/packet.zig similarity index 100% rename from src/gossip/packet.zig rename to src/net/packet.zig diff --git a/src/net/socket_utils.zig b/src/net/socket_utils.zig index b88ac6196..10ab67b46 100644 --- a/src/net/socket_utils.zig +++ b/src/net/socket_utils.zig @@ -1,6 +1,8 @@ +const Allocator = std.mem.Allocator; +const Atomic = std.atomic.Atomic; const UdpSocket = @import("zig-network").Socket; -const Packet = @import("../gossip/packet.zig").Packet; -const PACKET_DATA_SIZE = @import("../gossip/packet.zig").PACKET_DATA_SIZE; +const Packet = 
@import("packet.zig").Packet; +const PACKET_DATA_SIZE = @import("packet.zig").PACKET_DATA_SIZE; const Channel = @import("../sync/channel.zig").Channel; const std = @import("std"); const Logger = @import("../trace/log.zig").Logger; @@ -142,6 +144,43 @@ pub fn sendSocket( logger.debugf("sendSocket loop closed", .{}); } +/// A thread that is dedicated to either sending or receiving data over a socket. +/// The included channel can be used communicate with that thread. +/// +/// The channel only supports one: either sending or receiving, depending how it +/// was initialized. While you *could* send data to the channel for a "receiver" +/// socket, the underlying thread won't actually read the data from the channel. +pub const SocketThread = struct { + channel: *Channel(std.ArrayList(Packet)), + exit: *std.atomic.Atomic(bool), + handle: std.Thread, + + const Self = @This(); + + pub fn initSender(allocator: Allocator, logger: Logger, socket: *UdpSocket, exit: *Atomic(bool)) !Self { + const channel = Channel(std.ArrayList(Packet)).init(allocator, 0); + return .{ + .channel = channel, + .exit = exit, + .handle = try std.Thread.spawn(.{}, sendSocket, .{ socket, channel, exit, logger }), + }; + } + + pub fn initReceiver(allocator: Allocator, logger: Logger, socket: *UdpSocket, exit: *Atomic(bool)) !Self { + const channel = Channel(std.ArrayList(Packet)).init(allocator, 0); + return .{ + .channel = channel, + .exit = exit, + .handle = try std.Thread.spawn(.{}, readSocket, .{ allocator, socket, channel, exit, logger }), + }; + } + + pub fn deinit(self: Self) void { + self.exit.store(true, .Unordered); + self.handle.join(); + } +}; + pub const BenchmarkPacketProcessing = struct { pub const min_iterations = 3; pub const max_iterations = 5; diff --git a/src/sync/channel.zig b/src/sync/channel.zig index c64fe573e..e9a9fcce2 100644 --- a/src/sync/channel.zig +++ b/src/sync/channel.zig @@ -163,7 +163,7 @@ fn testSender(chan: *BlockChannel, total_send: usize) void { chan.close(); } 
-const Packet = @import("../gossip/packet.zig").Packet; +const Packet = @import("../net/packet.zig").Packet; fn testPacketSender(chan: *Channel(Packet), total_send: usize) void { var i: usize = 0; while (i < total_send) : (i += 1) { diff --git a/src/trace/entry.zig b/src/trace/entry.zig index 9139a8331..bd107bf2b 100644 --- a/src/trace/entry.zig +++ b/src/trace/entry.zig @@ -216,5 +216,5 @@ test "trace.entry: should info log correctly" { .field("possible_value", anull) .logf(.info, "hello, {s}", .{"world!"}); - std.debug.print("{any}\n\n", .{logger}); + // std.debug.print("{any}\n\n", .{logger}); } diff --git a/src/trace/log.zig b/src/trace/log.zig index 5187eee37..9b3503fe8 100644 --- a/src/trace/log.zig +++ b/src/trace/log.zig @@ -121,6 +121,9 @@ pub const Logger = union(enum) { .noop => {}, } } + + /// Can be used in tests to minimize the amount of logging during tests. + pub const TEST_DEFAULT_LEVEL: Level = .warn; }; pub const StandardErrLogger = struct { @@ -248,7 +251,8 @@ pub const StdErrSink = struct { }; test "trace.logger: works" { - var logger = Logger.init(testing.allocator, .info); + var logger: Logger = .noop; // uncomment below to run visual test + // var logger = Logger.init(testing.allocator, .info); logger.spawn(); defer logger.deinit(); diff --git a/src/tvu/repair_message.zig b/src/tvu/repair_message.zig new file mode 100644 index 000000000..b4a16c31a --- /dev/null +++ b/src/tvu/repair_message.zig @@ -0,0 +1,479 @@ +const std = @import("std"); +const sig = @import("../lib.zig"); + +const bincode = sig.bincode; + +const Allocator = std.mem.Allocator; +const KeyPair = std.crypto.sign.Ed25519.KeyPair; + +const LegacyContactInfo = sig.gossip.LegacyContactInfo; +const Nonce = sig.core.Nonce; +const Pong = sig.gossip.Pong; +const Pubkey = sig.core.Pubkey; +const Signature = sig.core.Signature; +const Slot = sig.core.Slot; + +const SIGNATURE_LENGTH = sig.core.SIGNATURE_LENGTH; + +/// Analogous to `SIGNED_REPAIR_TIME_WINDOW` +const 
SIGNED_REPAIR_TIME_WINDOW_SECS: u64 = 600;

/// Internal representation of a repair request.
/// Does not contain any header or identification, only info about the desired shreds.
///
/// Analogous to `solana_core::repair::serve_repair::ShredRepairType`
pub const RepairRequest = union(enum) {
    /// Requesting `MAX_ORPHAN_REPAIR_RESPONSES` parent shreds
    Orphan: Slot,
    /// Requesting any shred with index greater than or equal to the particular index
    HighestShred: struct { Slot, u64 },
    /// Requesting the missing shred at a particular index
    Shred: struct { Slot, u64 },
};

/// Executes all three because they are tightly coupled:
/// - convert request to message
/// - serialize message
/// - sign message
///
/// The returned slice is allocated with `allocator`; the caller owns it and
/// must free it. On any error the intermediate buffer is released here.
///
/// Analogous to `ServeRepair::map_repair_request`
pub fn serializeRepairRequest(
    allocator: Allocator,
    request: RepairRequest,
    keypair: *const KeyPair,
    recipient: Pubkey,
    timestamp: u64,
    nonce: Nonce,
) ![]u8 {
    const header = RepairRequestHeader{
        // placeholder: the real signature is copied in after serialization
        .signature = Signature.init(undefined),
        .sender = try Pubkey.fromBytes(&keypair.public_key.bytes),
        .recipient = recipient,
        .timestamp = timestamp,
        .nonce = nonce,
    };
    const msg: RepairMessage = switch (request) {
        .Shred => |r| .{ .WindowIndex = .{
            .header = header,
            .slot = r[0],
            .shred_index = r[1],
        } },
        .HighestShred => |r| .{ .HighestWindowIndex = .{
            .header = header,
            .slot = r[0],
            .shred_index = r[1],
        } },
        .Orphan => |r| .{ .Orphan = .{
            .header = header,
            .slot = r,
        } },
    };

    // `serialized` always names the live allocation, so the errdefer frees the
    // right buffer whether a failure happens before or after the shrink.
    var serialized: []u8 = try allocator.alloc(u8, RepairMessage.MAX_SERIALIZED_SIZE);
    errdefer allocator.free(serialized);

    var stream = std.io.fixedBufferStream(serialized);
    try bincode.write(null, stream.writer(), msg, .{});
    serialized = try allocator.realloc(serialized, stream.pos);

    // The signature covers the leading 4 bytes (the variant tag) plus everything
    // after the signature field; the signature itself is then written in place.
    var signer = try keypair.signer(null); // TODO noise
    signer.update(serialized[0..4]);
    signer.update(serialized[4 + SIGNATURE_LENGTH ..]);
    @memcpy(serialized[4 .. 4 + SIGNATURE_LENGTH], &signer.finalize().toBytes());

    return serialized;
}

/// Messaging data that is directly serialized and sent over repair sockets.
/// Contains any header/identification as needed.
///
/// Analogous to `solana_core::repair::serve_repair::RepairProtocol`
pub const RepairMessage = union(enum(u8)) {
    Pong: Pong = 7,
    WindowIndex: struct {
        header: RepairRequestHeader,
        slot: Slot,
        shred_index: u64,
    },
    HighestWindowIndex: struct {
        header: RepairRequestHeader,
        slot: Slot,
        shred_index: u64,
    },
    Orphan: struct {
        header: RepairRequestHeader,
        slot: Slot,
    },
    AncestorHashes: struct {
        header: RepairRequestHeader,
        slot: Slot,
    },

    pub const Tag: type = @typeInfo(@This()).Union.tag_type.?;

    /// Size of the buffer used when serializing a message.
    const MAX_SERIALIZED_SIZE: usize = 160;

    /// Field-by-field equality. Messages with different active tags are never equal.
    pub fn eql(self: *const @This(), other: *const @This()) bool {
        // Compare the tags directly rather than their names as strings.
        if (std.meta.activeTag(self.*) != std.meta.activeTag(other.*)) {
            return false;
        }
        return switch (self.*) {
            .Pong => |*s| s.eql(&other.Pong),
            .WindowIndex => |*s| s.header.eql(&other.WindowIndex.header) and
                s.slot == other.WindowIndex.slot and
                s.shred_index == other.WindowIndex.shred_index,
            .HighestWindowIndex => |*s| s.header.eql(&other.HighestWindowIndex.header) and
                s.slot == other.HighestWindowIndex.slot and
                s.shred_index == other.HighestWindowIndex.shred_index,
            .Orphan => |*s| s.header.eql(&other.Orphan.header) and
                s.slot == other.Orphan.slot,
            .AncestorHashes => |*s| s.header.eql(&other.AncestorHashes.header) and
                s.slot == other.AncestorHashes.slot,
        };
    }

    /// Checks that a received message is addressed to `expected_recipient`,
    /// was created within the accepted time window, and carries a valid
    /// signature over `serialized`. `Pong` messages delegate to `Pong.verify`.
    ///
    /// Analogous to `ServeRepair::verify_signed_packet`
    pub fn verify(
        self: *const @This(),
        /// bincode serialized data, from which this struct was deserialized
        serialized: []u8,
        /// to compare to the header. typically is this validator's own pubkey
        expected_recipient: Pubkey,
        /// unix timestamp in milliseconds when this function is called
        current_timestamp_millis: u64,
    ) error{ IdMismatch, InvalidSignature, Malformed, TimeSkew }!void {
        switch (self.*) {
            .Pong => |pong| try pong.verify(),
            inline else => |msg| {
                // i am the intended recipient
                const header: RepairRequestHeader = msg.header;
                if (!header.recipient.equals(&expected_recipient)) {
                    return error.IdMismatch;
                }

                // message was generated recently
                // Both timestamps are in milliseconds while the window constant is
                // in seconds, so convert the window before comparing.
                const time_diff = @as(i128, current_timestamp_millis) - @as(i128, header.timestamp);
                const time_diff_abs = std.math.absInt(time_diff) catch unreachable;
                if (time_diff_abs > SIGNED_REPAIR_TIME_WINDOW_SECS * std.time.ms_per_s) {
                    return error.TimeSkew;
                }

                // signature is valid
                if (serialized.len < 4 + SIGNATURE_LENGTH) {
                    return error.Malformed;
                }
                var verifier = header.signature.verifier(header.sender) catch {
                    return error.InvalidSignature;
                };
                // Same byte ranges that `serializeRepairRequest` signs.
                verifier.update(serialized[0..4]);
                verifier.update(serialized[4 + SIGNATURE_LENGTH ..]);
                verifier.verify() catch {
                    return error.InvalidSignature;
                };
            },
        }
    }
};

/// Header carried by every signed repair request variant.
pub const RepairRequestHeader = struct {
    signature: Signature,
    sender: Pubkey,
    recipient: Pubkey,
    timestamp: u64,
    nonce: Nonce,

    fn eql(self: *const @This(), other: *const @This()) bool {
        return self.signature.eql(&other.signature) and
            self.sender.equals(&other.sender) and
            self.recipient.equals(&other.recipient) and
            self.timestamp == other.timestamp and
            self.nonce == other.nonce;
    }
};

+test "tvu.repair_message: signed/serialized RepairRequest is valid" { + const allocator = std.testing.allocator; + var rand = std.rand.DefaultPrng.init(392138); + const rng = rand.random(); + + inline for (.{ + RepairRequest{ .Orphan = rng.int(Slot) }, + RepairRequest{ .Shred = .{ rng.int(Slot), rng.int(u64) } }, + RepairRequest{ .HighestShred = .{ rng.int(Slot), rng.int(u64) } }, + }) |request| { + var kp_noise: [32]u8 = 
undefined; + rng.bytes(&kp_noise); + const keypair = try KeyPair.create(kp_noise); + const recipient = Pubkey.random(rng); + const timestamp = rng.int(u64); + const nonce = rng.int(Nonce); + + var serialized = try serializeRepairRequest( + allocator, + request, + &keypair, + recipient, + timestamp, + nonce, + ); + defer allocator.free(serialized); + + var deserialized = try bincode.readFromSlice(allocator, RepairMessage, serialized, .{}); + try deserialized.verify(serialized, recipient, timestamp); + + serialized[10] = 0; // >99% chance that this invalidates the signature + var bad = try bincode.readFromSlice(allocator, RepairMessage, serialized, .{}); + if (bad.verify(serialized, recipient, timestamp)) |_| @panic("should err") else |_| {} + } +} + +test "tvu.repair_message: RepairRequestHeader serialization round trip" { + var rng = std.rand.DefaultPrng.init(5224); + var signature: [sig.core.SIGNATURE_LENGTH]u8 = undefined; + rng.fill(&signature); + + const header = RepairRequestHeader{ + .signature = Signature.init(signature), + .sender = Pubkey.random(rng.random()), + .recipient = Pubkey.random(rng.random()), + .timestamp = 5924, + .nonce = 123, + }; + + var buf: [RepairMessage.MAX_SERIALIZED_SIZE]u8 = undefined; + const serialized = try bincode.writeToSlice(&buf, header, .{}); + + const expected = [_]u8{ + 39, 95, 42, 53, 95, 32, 120, 241, 244, 206, 142, 80, 233, 26, 232, 206, 241, + 24, 226, 101, 183, 172, 170, 201, 42, 127, 121, 127, 213, 234, 180, 0, 226, 0, + 128, 58, 176, 144, 99, 139, 220, 112, 10, 117, 212, 239, 129, 197, 170, 11, 92, + 151, 239, 163, 174, 85, 172, 227, 75, 115, 1, 143, 134, 9, 21, 189, 8, 17, + 240, 55, 159, 41, 45, 133, 143, 153, 57, 113, 39, 28, 86, 183, 182, 76, 41, + 19, 160, 55, 54, 41, 126, 184, 144, 195, 245, 38, 164, 157, 171, 233, 18, 178, + 15, 2, 196, 46, 124, 59, 178, 108, 95, 194, 39, 18, 119, 16, 226, 118, 112, + 26, 255, 82, 27, 175, 162, 144, 207, 151, 36, 23, 0, 0, 0, 0, 0, 0, + 123, 0, 0, 0, + }; + + try 
std.testing.expect(std.mem.eql(u8, &expected, serialized)); + + const roundtripped = try bincode.readFromSlice( + std.testing.allocator, + RepairRequestHeader, + serialized, + .{}, + ); + try std.testing.expect(header.eql(&roundtripped)); +} + +test "tvu.repair_message: RepairProtocolMessage.Pong serialization round trip" { + try testHelpers.assertMessageSerializesCorrectly(57340, .Pong, &[_]u8{ + 7, 0, 0, 0, 252, 143, 181, 36, 240, 87, 69, 104, 157, 159, 242, 94, 101, + 48, 187, 120, 173, 241, 68, 167, 217, 67, 141, 46, 105, 85, 179, 69, 249, 140, + 6, 145, 6, 201, 32, 10, 11, 24, 157, 240, 245, 65, 91, 80, 255, 89, 18, + 136, 27, 80, 101, 106, 118, 175, 154, 105, 205, 69, 2, 112, 61, 168, 217, 197, + 251, 212, 16, 137, 153, 40, 116, 229, 235, 90, 12, 54, 76, 123, 187, 108, 132, + 78, 151, 13, 47, 0, 127, 182, 158, 5, 19, 226, 204, 0, 120, 218, 175, 155, + 122, 155, 94, 44, 198, 119, 196, 127, 121, 242, 98, 87, 235, 233, 241, 57, 53, + 125, 88, 67, 4, 23, 164, 128, 221, 124, 139, 84, 106, 7, + }); +} + +test "tvu.repair_message: RepairProtocolMessage.WindowIndex serialization round trip" { + try testHelpers.assertMessageSerializesCorrectly(4823794, .WindowIndex, &[_]u8{ + 8, 0, 0, 0, 100, 7, 241, 74, 194, 88, 24, 128, 85, 15, 149, 108, 142, + 133, 234, 217, 3, 79, 124, 171, 68, 30, 189, 219, 173, 11, 184, 159, 208, 104, + 206, 31, 233, 86, 166, 102, 235, 97, 198, 145, 62, 149, 19, 202, 91, 237, 153, + 175, 64, 205, 96, 10, 66, 7, 66, 104, 119, 214, 232, 34, 168, 170, 191, 254, + 170, 237, 236, 185, 88, 155, 113, 136, 171, 26, 210, 220, 45, 195, 26, 211, 174, + 235, 79, 241, 31, 60, 134, 15, 207, 28, 50, 96, 253, 80, 191, 140, 108, 58, + 53, 196, 143, 167, 65, 56, 105, 42, 146, 49, 136, 194, 147, 74, 110, 247, 135, + 48, 92, 138, 71, 230, 204, 175, 17, 87, 167, 45, 210, 99, 50, 122, 47, 19, + 19, 197, 58, 51, 19, 223, 45, 162, 128, 200, 255, 158, 217, 0, 235, 83, 78, + 233, 7, 127, 119, 47, 7, 223, + }); +} + +test "tvu.repair_message: 
RepairProtocolMessage.HighestWindowIndex serialization round trip" { + try testHelpers.assertMessageSerializesCorrectly(636345, .HighestWindowIndex, &[_]u8{ + 9, 0, 0, 0, 44, 123, 16, 108, 173, 151, 229, 132, 4, 0, 5, 215, 25, + 179, 235, 166, 181, 42, 30, 231, 218, 43, 166, 238, 92, 80, 234, 87, 30, 123, + 140, 27, 65, 165, 32, 139, 235, 225, 146, 239, 107, 162, 4, 80, 215, 131, 42, + 94, 28, 153, 26, 191, 57, 87, 214, 211, 145, 158, 113, 53, 178, 178, 33, 217, + 204, 75, 59, 119, 212, 148, 21, 154, 19, 106, 222, 14, 10, 225, 243, 182, 32, + 149, 101, 1, 226, 133, 56, 84, 175, 53, 65, 157, 177, 34, 153, 171, 107, 230, + 177, 30, 169, 141, 24, 248, 39, 184, 152, 55, 108, 199, 61, 232, 189, 152, 129, + 249, 88, 86, 204, 12, 134, 9, 185, 8, 176, 163, 50, 51, 149, 144, 227, 124, + 63, 248, 112, 172, 251, 252, 42, 232, 95, 7, 74, 139, 26, 36, 163, 156, 135, + 113, 204, 230, 147, 29, 223, 167, + }); +} + +test "tvu.repair_message: RepairProtocolMessage.Orphan serialization round trip" { + try testHelpers.assertMessageSerializesCorrectly(734566, .Orphan, &[_]u8{ + 10, 0, 0, 0, 52, 54, 182, 49, 197, 238, 253, 118, 145, 61, 198, 235, 42, + 211, 229, 42, 2, 33, 5, 161, 179, 171, 26, 243, 51, 240, 82, 98, 121, 90, + 210, 244, 120, 168, 226, 131, 209, 42, 251, 16, 90, 129, 113, 90, 195, 130, 55, + 58, 97, 240, 114, 59, 154, 38, 7, 66, 209, 77, 18, 17, 22, 1, 65, 184, + 202, 21, 198, 105, 238, 24, 115, 147, 78, 249, 178, 229, 75, 189, 129, 104, 138, + 75, 78, 30, 54, 222, 175, 51, 218, 247, 211, 188, 142, 76, 64, 156, 21, 191, + 163, 86, 38, 244, 0, 213, 69, 78, 102, 190, 220, 19, 138, 92, 30, 149, 125, + 135, 239, 186, 78, 147, 83, 128, 23, 200, 81, 2, 102, 110, 226, 11, 217, 50, + 27, 76, 129, 55, 218, 236, 152, 27, 164, 106, 186, 169, 80, 103, 36, 153, + }); +} + +test "tvu.repair_message: RepairProtocolMessage.AncestorHashes serialization round trip" { + try testHelpers.assertMessageSerializesCorrectly(6236757, .AncestorHashes, &[_]u8{ + 11, 0, 0, 0, 192, 86, 218, 156, 
168, 139, 216, 200, 30, 181, 244, 121, 90, + 41, 177, 117, 55, 40, 199, 207, 62, 118, 56, 134, 73, 88, 74, 2, 139, 189, + 201, 150, 22, 75, 239, 15, 35, 125, 154, 130, 165, 120, 24, 154, 159, 42, 222, + 92, 189, 252, 136, 151, 184, 96, 137, 169, 181, 62, 108, 82, 235, 143, 42, 93, + 212, 223, 9, 217, 201, 202, 143, 14, 99, 140, 33, 48, 241, 185, 240, 10, 146, + 127, 62, 122, 247, 66, 91, 169, 32, 251, 220, 5, 197, 184, 172, 190, 182, 248, + 69, 46, 30, 121, 156, 153, 238, 91, 192, 207, 163, 187, 60, 71, 60, 232, 71, + 228, 195, 225, 162, 193, 230, 37, 128, 114, 73, 252, 29, 20, 164, 63, 220, 2, + 32, 166, 102, 87, 214, 59, 20, 255, 18, 190, 186, 206, 159, 97, 45, 99, + }); +} + +test "tvu.repair_message: RepairProtocolMessage serializes to size <= MAX_SERIALIZED_SIZE" { + var rng = std.rand.DefaultPrng.init(184837); + for (0..10) |_| { + inline for (@typeInfo(RepairMessage.Tag).Enum.fields) |enum_field| { + const tag = @field(RepairMessage.Tag, enum_field.name); + const msg = testHelpers.randomRepairProtocolMessage(rng.random(), tag); + var buf: [RepairMessage.MAX_SERIALIZED_SIZE]u8 = undefined; + _ = try bincode.writeToSlice(&buf, msg, .{}); + } + } +} + +const testHelpers = struct { + fn assertMessageSerializesCorrectly( + seed: u64, + tag: RepairMessage.Tag, + expected: []const u8, + ) !void { + var rng = std.rand.DefaultPrng.init(seed); + const msg = testHelpers.randomRepairProtocolMessage(rng.random(), tag); + debugMessage(&msg); + + var buf: [RepairMessage.MAX_SERIALIZED_SIZE]u8 = undefined; + const serialized = try bincode.writeToSlice(&buf, msg, .{}); + try std.testing.expect(std.mem.eql(u8, expected, serialized)); + + switch (msg) { + .Pong => |_| try msg.verify(serialized, undefined, 0), + inline else => |m| { + const result = msg.verify(serialized, m.header.recipient, m.header.timestamp); + if (result) |_| @panic("should fail due to signature") else |_| {} + }, + } + + const roundtripped = try bincode.readFromSlice( + std.testing.allocator, + 
RepairMessage, + serialized, + .{}, + ); + try std.testing.expect(msg.eql(&roundtripped)); + + // // rust template to generate expectation: + // let header = RepairRequestHeader { + // signature: Signature::new(&[]), + // sender: Pubkey::from([]), + // recipient: Pubkey::from([]), + // timestamp: , + // nonce: , + // }; + // let msg = RepairProtocol::AncestorHashes { + // header, + // slot: , + // }; + // let data = bincode::serialize(&msg).unwrap(); + // println!("{data:?}"); + } + + fn randomRepairRequestHeader(rng: std.rand.Random) RepairRequestHeader { + var signature: [sig.core.SIGNATURE_LENGTH]u8 = undefined; + rng.bytes(&signature); + + return RepairRequestHeader{ + .signature = Signature.init(signature), + .sender = Pubkey.random(rng), + .recipient = Pubkey.random(rng), + .timestamp = rng.int(u64), + .nonce = rng.int(u32), + }; + } + + fn randomRepairProtocolMessage( + rng: std.rand.Random, + message_type: RepairMessage.Tag, + ) RepairMessage { + return switch (message_type) { + .Pong => x: { + var buf: [32]u8 = undefined; + rng.bytes(&buf); + const kp = KeyPair.create(buf) catch unreachable; + break :x .{ .Pong = Pong.random(rng, &(kp)) catch unreachable }; + }, + .WindowIndex => .{ .WindowIndex = .{ + .header = randomRepairRequestHeader(rng), + .slot = rng.int(Slot), + .shred_index = rng.int(u64), + } }, + .HighestWindowIndex => .{ .HighestWindowIndex = .{ + .header = randomRepairRequestHeader(rng), + .slot = rng.int(Slot), + .shred_index = rng.int(u64), + } }, + .Orphan => .{ .Orphan = .{ + .header = randomRepairRequestHeader(rng), + .slot = rng.int(Slot), + } }, + .AncestorHashes => .{ .AncestorHashes = .{ + .header = randomRepairRequestHeader(rng), + .slot = rng.int(Slot), + } }, + }; + } + + const DEBUG: bool = false; + + fn debugMessage(message: *const RepairMessage) void { + if (!DEBUG) return; + std.debug.print("_\n\n", .{}); + switch (message.*) { + .Pong => |*msg| { + std.debug.print("from: {any}\n\n", .{msg.from}); + std.debug.print("hash: 
{any}\n\n", .{msg.hash}); + std.debug.print("signature: {any}\n\n", .{msg.signature}); + }, + .WindowIndex => |*msg| { + debugHeader(msg.header); + }, + .HighestWindowIndex => |*msg| { + debugHeader(msg.header); + }, + .Orphan => |*msg| { + debugHeader(msg.header); + }, + .AncestorHashes => |*msg| { + debugHeader(msg.header); + }, + } + std.debug.print("{any}", .{message}); + } + + fn debugHeader(header: RepairRequestHeader) void { + if (!DEBUG) return; + std.debug.print("nonce: {any}\n\n", .{header.nonce}); + std.debug.print("recipient: {any}\n\n", .{header.recipient.data}); + std.debug.print("sender: {any}\n\n", .{header.sender.data}); + std.debug.print("signature: {any}\n\n", .{header.signature.data}); + std.debug.print("timestamp: {any}\n\n", .{header.timestamp}); + } +}; diff --git a/src/tvu/repair_service.zig b/src/tvu/repair_service.zig new file mode 100644 index 000000000..51f94bb5e --- /dev/null +++ b/src/tvu/repair_service.zig @@ -0,0 +1,436 @@ +const std = @import("std"); +const zig_network = @import("zig-network"); +const sig = @import("../lib.zig"); + +const Allocator = std.mem.Allocator; +const Atomic = std.atomic.Atomic; +const KeyPair = std.crypto.sign.Ed25519.KeyPair; +const Random = std.rand.Random; +const Socket = zig_network.Socket; + +const ContactInfo = sig.gossip.ContactInfo; +const GossipTable = sig.gossip.GossipTable; +const Logger = sig.trace.Logger; +const LruCacheCustom = sig.common.LruCacheCustom; +const Nonce = sig.core.Nonce; +const Pubkey = sig.core.Pubkey; +const RwMux = sig.sync.RwMux; +const SocketAddr = sig.net.SocketAddr; +const Slot = sig.core.Slot; + +const RepairRequest = sig.tvu.RepairRequest; +const serializeRepairRequest = sig.tvu.serializeRepairRequest; + +/// Identifies which repairs are needed and sends them +/// - delegates to RepairPeerProvider to identify repair peers. +/// - delegates to RepairRequester to send the requests. 
pub const RepairService = struct {
    allocator: Allocator,
    requester: RepairRequester,
    peer_provider: RepairPeerProvider,
    logger: Logger,
    /// Polled by `run`; the loop stops once this reads true.
    exit: *Atomic(bool),
    /// Slot requested on every loop iteration; null disables the request.
    slot_to_request: ?u64,

    /// Releases the peer provider held by this service.
    pub fn deinit(self: *@This()) void {
        self.peer_provider.deinit();
    }

    /// Loops until `exit` is set, sending at most one repair request per
    /// iteration and sleeping 100ms in between. Any error from peer lookup
    /// or sending ends the loop and is returned.
    pub fn run(self: *@This()) !void {
        self.logger.info("starting repair service");
        defer self.logger.info("exiting repair service");
        while (!self.exit.load(.Unordered)) {
            const maybe_request = try self.initialSnapshotRepair();
            if (maybe_request) |addressed| {
                try self.requester.sendRepairRequest(addressed);
            }
            // TODO repair logic
            std.time.sleep(100 * std.time.ns_per_ms);
        }
    }

    /// Builds a `HighestShred` request (index 0) for `slot_to_request`, addressed
    /// to a random peer. Returns null when no slot is configured or no peer is
    /// available.
    fn initialSnapshotRepair(self: *@This()) !?AddressedRepairRequest {
        const slot = self.slot_to_request orelse return null;
        const peer = (try self.peer_provider.getRandomPeer(slot)) orelse return null;
        return .{
            .request = .{ .HighestShred = .{ slot, 0 } },
            .recipient = peer.pubkey,
            .recipient_addr = peer.serve_repair_socket,
        };
    }
};

/// Signs and serializes repair requests. Sends them over the network.
pub const RepairRequester = struct {
    allocator: Allocator,
    rng: Random,
    keypair: *const KeyPair,
    udp_send_socket: *Socket,
    logger: Logger,

    /// Serializes and signs `request` with the current millisecond timestamp and
    /// a random nonce, logs the destination, and sends the bytes to
    /// `request.recipient_addr`. The serialized buffer is freed before returning.
    pub fn sendRepairRequest(
        self: *const @This(),
        request: AddressedRepairRequest,
    ) !void {
        const now_millis = std.time.milliTimestamp();
        const payload = try serializeRepairRequest(
            self.allocator,
            request.request,
            self.keypair,
            request.recipient,
            @intCast(now_millis),
            self.rng.int(Nonce),
        );
        defer self.allocator.free(payload);

        // `toString` yields a (buffer, length) pair; log only the used prefix.
        const addr_str = request.recipient_addr.toString();
        self.logger.infof(
            "sending repair request to {s} - {}",
            .{ addr_str[0][0..addr_str[1]], request.request },
        );
        _ = try self.udp_send_socket.sendTo(request.recipient_addr.toEndpoint(), payload);
    }
};

/// A repair request plus its destination.
pub const AddressedRepairRequest = struct {
    /// What is being asked for.
    request: RepairRequest,
    /// Identity of the node the request is addressed to.
    recipient: Pubkey,
    /// Network address the serialized request is sent to.
    recipient_addr: SocketAddr,
};

/// How many slots to cache in RepairPeerProvider
const REPAIR_PEERS_CACHE_CAPACITY: usize = 128;
/// Maximum age of a cache item to use for repair peers
const REPAIR_PEERS_CACHE_TTL_SECONDS: u64 = 10;

/// A node that can service a repair request.
pub const RepairPeer = struct {
    /// Identity of the peer.
    pubkey: Pubkey,
    /// Address the peer accepts repair requests on.
    serve_repair_socket: SocketAddr,
};

/// Provides peers that repair requests can be sent to.
///
/// TODO benchmark the performance of some alternate approaches, for example:
/// - directly grab a single random peer from gossip instead
///   of the entire list (good if we don't access a slot many times)
/// - single sorted cache for all slots with a binary search to filter by slot
///   - upside is fewer table locks
///   - good if we're mainly looking at older slots, not the last few slots
///   - downside is it may get stale and not represent the latest slots,
///     unless cache TTL is reduced to less than 1 slot duration, in which
///     case it may defeat the purpose of this approach.
/// The key for these benchmarks is to understand the actual repair requests that
/// are being requested on mainnet. There are trade-offs for different kinds
/// of requests. Naive benchmarks will optimize the wrong behaviors.
pub const RepairPeerProvider = struct {
    allocator: Allocator,
    rng: Random,
    gossip_table_rw: *RwMux(GossipTable),
    cache: PeerCache,
    my_pubkey: Pubkey,
    my_shred_version: *const Atomic(u16),

    /// Per-slot cache of peer lists; entries are released through `RepairPeers.deinit`.
    const PeerCache = LruCacheCustom(.non_locking, Slot, RepairPeers, Allocator, RepairPeers.deinit);

    /// One cache entry: the peers found for a slot and when they were looked up.
    const RepairPeers = struct {
        insertion_time_secs: u64,
        peers: []RepairPeer,

        fn deinit(self: *@This(), allocator: Allocator) void {
            allocator.free(self.peers);
        }
    };

    /// Creates a provider with an empty cache of `REPAIR_PEERS_CACHE_CAPACITY` slots.
    /// `gossip` and `my_shred_version` are borrowed and must outlive the provider.
    pub fn init(
        allocator: Allocator,
        rng: Random,
        gossip: *RwMux(GossipTable),
        my_pubkey: Pubkey,
        my_shred_version: *const Atomic(u16),
    ) error{OutOfMemory}!RepairPeerProvider {
        return .{
            .allocator = allocator,
            .gossip_table_rw = gossip,
            .cache = try PeerCache.initWithContext(allocator, REPAIR_PEERS_CACHE_CAPACITY, allocator),
            .my_pubkey = my_pubkey,
            .my_shred_version = my_shred_version,
            .rng = rng,
        };
    }

    /// Releases the cache.
    pub fn deinit(self: *@This()) void {
        self.cache.deinit();
    }

    /// Selects a peer at random from gossip or cache that is expected
    /// to be able to handle a repair request for the specified slot.
    /// Returns null when no suitable peer is known.
    pub fn getRandomPeer(self: *@This(), slot: Slot) !?RepairPeer {
        const candidates = try self.getPeers(slot);
        if (candidates.len == 0) return null;
        const pick = self.rng.intRangeLessThan(usize, 0, candidates.len);
        return candidates[pick];
    }

    /// Tries to get peers that could have the slot. Checks cache, falling back to gossip.
    ///
    /// The returned slice is stored in the cache, so callers must not free it.
    fn getPeers(self: *@This(), slot: Slot) ![]RepairPeer {
        const now: u64 = @intCast(std.time.timestamp());

        if (self.cache.get(slot)) |cached| {
            // Saturating subtraction: the wall clock can step backwards, and a plain
            // `-` on u64 would then overflow. Such an entry is treated as fresh.
            const age_secs = now -| cached.insertion_time_secs;
            if (age_secs <= REPAIR_PEERS_CACHE_TTL_SECONDS) {
                return cached.peers;
            }
        }

        const fresh = try self.getRepairPeersFromGossip(self.allocator, slot);
        // NOTE(review): if `insert` fails, `fresh` is not freed here — confirm
        // whether LruCacheCustom.insert takes ownership of the value on error.
        try self.cache.insert(slot, .{
            .insertion_time_secs = now,
            .peers = fresh,
        });
        return fresh;
    }

    /// Gets a list of peers from the gossip table that are likely to have the desired slot.
    /// This will always acquire the gossip table lock.
    /// Instead of using this function, access the cache when possible to avoid contention.
    ///
    /// The returned slice is allocated with `allocator` and owned by the caller.
    fn getRepairPeersFromGossip(
        self: *@This(),
        allocator: Allocator,
        slot: Slot,
    ) error{OutOfMemory}![]RepairPeer {
        var gossip_table_lock = self.gossip_table_rw.read();
        defer gossip_table_lock.unlock();
        const gossip_table: *const GossipTable = gossip_table_lock.get();

        // Sized for the worst case (every contact info qualifies), shrunk at the end.
        const buf = try allocator.alloc(RepairPeer, gossip_table.contact_infos.count());
        errdefer allocator.free(buf);

        var i: usize = 0;
        var infos = gossip_table.contactInfoIterator(0);
        while (infos.next()) |info| {
            // don't request from self
            if (info.pubkey.equals(&self.my_pubkey)) continue;
            // need compatible shreds
            if (info.shred_version != self.my_shred_version.load(.Monotonic)) continue;
            // node must be able to receive repair requests
            const serve_repair_socket = info.getSocket(sig.gossip.socket_tag.SERVE_REPAIR) orelse continue;
            // node needs access to shreds
            if (info.getSocket(sig.gossip.socket_tag.TVU) == null) continue;
            // exclude nodes that are known to be missing this slot
            if (gossip_table.get(.{ .LowestSlot = info.pubkey })) |lsv| {
                if (lsv.value.data.LowestSlot[1].lowest > slot) continue;
            }
            buf[i] = .{
                .pubkey = info.pubkey,
                .serve_repair_socket = serve_repair_socket,
            };
            i += 1;
        }
        return try allocator.realloc(buf, i);
    }
};

+test "tvu.repair_service: RepairService sends 
repair request to gossip peer" { + const SignedGossipData = sig.gossip.SignedGossipData; + const allocator = std.testing.allocator; + var rand = std.rand.DefaultPrng.init(4328095); + var random = rand.random(); + + // my details + const keypair = KeyPair.create(null) catch unreachable; + const my_shred_version = Atomic(u16).init(random.int(u16)); + const wallclock = 100; + var gossip = try GossipTable.init(allocator, undefined); + defer gossip.deinit(); + var logger = Logger.init(allocator, Logger.TEST_DEFAULT_LEVEL); + defer logger.deinit(); + + // connectivity + const repair_port = random.intRangeAtMost(u16, 1000, std.math.maxInt(u16)); + var repair_socket = try Socket.create(.ipv4, .udp); + try repair_socket.bind(.{ + .port = repair_port, + .address = .{ .ipv4 = .{ .value = .{ 0, 0, 0, 0 } } }, + }); + + // peer + const peer_port = random.intRangeAtMost(u16, 1000, std.math.maxInt(u16)); + const peer_keypair = KeyPair.create(null) catch unreachable; + var peer_socket = try Socket.create(.ipv4, .udp); + const peer_endpoint = .{ + .address = .{ .ipv4 = .{ .value = .{ 127, 0, 0, 1 } } }, + .port = peer_port, + }; + try peer_socket.bind(peer_endpoint); + try peer_socket.setReadTimeout(100_000); + var peer_contact_info = ContactInfo.init(allocator, Pubkey.fromPublicKey(&peer_keypair.public_key), wallclock, my_shred_version.load(.Unordered)); + try peer_contact_info.setSocket(sig.gossip.socket_tag.SERVE_REPAIR, SocketAddr.fromEndpoint(&peer_endpoint)); + try peer_contact_info.setSocket(sig.gossip.socket_tag.TVU, SocketAddr.fromEndpoint(&peer_endpoint)); + try gossip.insert(try SignedGossipData.initSigned(.{ .ContactInfo = peer_contact_info }, &peer_keypair), wallclock); + + // init service + var exit = Atomic(bool).init(false); + var gossip_mux = RwMux(GossipTable).init(gossip); + var peers = try RepairPeerProvider.init( + allocator, + random, + &gossip_mux, + Pubkey.fromPublicKey(&keypair.public_key), + &my_shred_version, + ); + var service = RepairService{ + 
test "tvu.repair_service: RepairPeerProvider selects correct peers" {
    // Slot that every generated peer is classified against and that is requested below.
    const slot = 13579;
    const sample_count: usize = 10;

    const allocator = std.testing.allocator;
    var prng = std.rand.DefaultPrng.init(4328095);
    const random = prng.random();

    // Local node identity.
    const own_keypair = KeyPair.create(null) catch unreachable;
    const my_shred_version = Atomic(u16).init(random.int(u16));
    var gossip = try GossipTable.init(allocator, undefined);
    defer gossip.deinit();
    var logger = Logger.init(allocator, Logger.TEST_DEFAULT_LEVEL);
    defer logger.deinit();

    // Populate the gossip table with one peer of every kind.
    // Each entry is a `{ PeerType, RepairPeer }` tuple as returned by `addPeerToGossip`.
    const generator = TestPeerGenerator{
        .allocator = allocator,
        .gossip = &gossip,
        .random = random,
        .shred_version = my_shred_version.load(.Unordered),
        .slot = slot,
    };
    const expected_peers = .{
        try generator.addPeerToGossip(.HasSlot),
        try generator.addPeerToGossip(.SlotPosessionUnclear),
    };
    const rejected_peers = .{
        try generator.addPeerToGossip(.MissingServeRepairPort),
        try generator.addPeerToGossip(.MissingTvuPort),
        try generator.addPeerToGossip(.MissingSlot),
        try generator.addPeerToGossip(.WrongShredVersion),
    };

    // Build the provider under test on top of the populated table.
    var gossip_mux = RwMux(GossipTable).init(gossip);
    const own_pubkey = Pubkey.fromPublicKey(&own_keypair.public_key);
    var provider = try RepairPeerProvider.init(
        allocator,
        random,
        &gossip_mux,
        own_pubkey,
        &my_shred_version,
    );
    defer provider.deinit();

    // Sample the provider repeatedly and record every distinct peer it hands out.
    var seen = std.AutoHashMap(RepairPeer, void).init(allocator);
    defer seen.deinit();
    var sample: usize = 0;
    while (sample < sample_count) : (sample += 1) {
        const picked = try provider.getRandomPeer(slot) orelse unreachable;
        try seen.put(picked, {});
    }

    // Every expected peer must have been observed and no rejected peer may appear.
    // All mismatches are printed before the single final expectation.
    var all_ok = true;
    inline for (expected_peers) |entry| {
        const was_seen = seen.contains(entry[1]);
        if (!was_seen) {
            std.debug.print("_\nMISSING: {}\n", .{entry[0]});
            all_ok = false;
        }
    }
    inline for (rejected_peers) |entry| {
        const was_seen = seen.contains(entry[1]);
        if (was_seen) {
            std.debug.print("_\nUNEXPECTED: {}\n", .{entry[0]});
            all_ok = false;
        }
    }
    try std.testing.expect(all_ok);
}
/// Analogous to `ShredFetchStage`
///
/// Drains packets arriving on `socket`, answers repair pings with a signed
/// pong, and logs every other packet as an unknown shred message.
pub const ShredReceiver = struct {
    /// Used for response batches and for temporary log strings.
    allocator: Allocator,
    /// Key used to sign pong replies.
    keypair: *const KeyPair,
    /// Set to true by the owner to make `run` return.
    exit: *Atomic(bool),
    logger: Logger,
    /// Socket shared by the sender and receiver socket threads.
    socket: *Socket,

    const Self = @This();

    /// Run threads to listen/send over socket and handle all incoming packets.
    /// Returns when exit is set to true.
    ///
    /// Any error from setting up the socket threads or from packet handling
    /// is propagated to the caller after being logged.
    pub fn run(self: *Self) !void {
        defer self.logger.err("exiting shred receiver");
        errdefer self.logger.err("error in shred receiver");

        var sender = try SocketThread.initSender(self.allocator, self.logger, self.socket, self.exit);
        defer sender.deinit();
        var receiver = try SocketThread.initReceiver(self.allocator, self.logger, self.socket, self.exit);
        defer receiver.deinit();

        try self.runPacketHandler(receiver.channel, sender.channel);
    }

    /// Keep looping over packet channel and process the incoming packets.
    /// Returns when exit is set to true.
    ///
    /// Replies produced while handling one drain are collected into a single
    /// batch and pushed to `sender`. When nothing was drained the loop sleeps
    /// for 10ms before polling again.
    fn runPacketHandler(
        self: *Self,
        receiver: *Channel(ArrayList(Packet)),
        sender: *Channel(ArrayList(Packet)),
    ) !void {
        while (!self.exit.load(.Unordered)) {
            // NOTE(review): ownership of `batches` (and of each batch list) is
            // not visible from this file. Confirm whether `try_drain` hands
            // ownership to the caller, in which case they must be freed here.
            if (try receiver.try_drain()) |batches| {
                var responses = ArrayList(Packet).init(self.allocator);
                // Release the batch if handling or sending fails; previously
                // an error after the first reply leaked the list.
                // NOTE(review): assumes a failed `send` did not keep the list.
                errdefer responses.deinit();

                for (batches) |batch| for (batch.items) |*packet| {
                    try self.handlePacket(packet, &responses);
                };

                if (responses.items.len > 0) {
                    // From here on the batch is no longer released by this loop.
                    try sender.send(responses);
                } else {
                    responses.deinit();
                }
            } else {
                std.time.sleep(10_000_000);
            }
        }
    }

    /// Handle a single packet and return.
    ///
    /// Packets whose size equals `REPAIR_RESPONSE_SERIALIZED_PING_BYTES` are
    /// treated as repair pings; anything else is only logged.
    fn handlePacket(self: *Self, packet: *const Packet, responses: *ArrayList(Packet)) !void {
        if (packet.size == REPAIR_RESPONSE_SERIALIZED_PING_BYTES) {
            try self.handlePing(packet, responses);
        } else {
            const endpoint_str = try sig.net.endpointToString(self.allocator, &packet.addr);
            defer endpoint_str.deinit();
            self.logger.field("from_endpoint", endpoint_str.items)
                .infof("tvu: recv unknown shred message: {} bytes", .{packet.size});
        }
    }

    /// Handle a ping message and return.
    ///
    /// A ping that fails to deserialize or verify is logged and dropped
    /// without an error. A valid ping appends one pong packet, addressed to
    /// the ping's source, to `responses`.
    fn handlePing(self: *Self, packet: *const Packet, responses: *ArrayList(Packet)) !void {
        // Only the bytes that were actually received are meaningful; the rest
        // of the fixed-size buffer must not be fed to the deserializer.
        const received = packet.data[0..packet.size];

        const repair_ping = bincode.readFromSlice(self.allocator, RepairPing, received, .{}) catch |e| {
            self.logger.errf("could not deserialize ping: {} - {any}", .{ e, packet.data[0..packet.size] });
            return;
        };
        const ping = repair_ping.Ping;
        ping.verify() catch |e| {
            self.logger.errf("ping failed verification: {} - {any}", .{ e, packet.data[0..packet.size] });
            return;
        };

        const reply = RepairMessage{ .Pong = try Pong.init(&ping, self.keypair) };
        {
            const reply_packet = try responses.addOne();
            // Do not leave a half-initialised packet in the batch if
            // serialization fails.
            errdefer _ = responses.pop();
            reply_packet.addr = packet.addr;
            const reply_bytes = try bincode.writeToSlice(&reply_packet.data, reply, .{});
            reply_packet.size = reply_bytes.len;
        }

        const endpoint_str = try sig.net.endpointToString(self.allocator, &packet.addr);
        defer endpoint_str.deinit();
        self.logger.field("from_endpoint", endpoint_str.items)
            .field("from_pubkey", &ping.from.string())
            .info("tvu: recv repair ping");
    }
};

/// Packets of exactly this size are treated as repair pings by `ShredReceiver`.
/// NOTE(review): presumably 4-byte variant tag + 32-byte pubkey + 32-byte token
/// + 64-byte signature; confirm against the `Ping` layout.
const REPAIR_RESPONSE_SERIALIZED_PING_BYTES = 132;

/// Wrapper used to deserialize an incoming ping.
/// NOTE(review): looks like the single-variant union exists so the decoder
/// consumes a leading variant tag before the `Ping` payload; confirm.
const RepairPing = union(enum) { Ping: Ping };