diff --git a/pkg/rewards/claim.go b/pkg/rewards/claim.go index 2e8d384d..6a825bcd 100644 --- a/pkg/rewards/claim.go +++ b/pkg/rewards/claim.go @@ -41,6 +41,7 @@ type elChainReader interface { ctx context.Context, ) (rewardscoordinator.IRewardsCoordinatorDistributionRoot, error) CurrRewardsCalculationEndTimestamp(ctx context.Context) (uint32, error) + GetCumulativeClaimed(ctx context.Context, earnerAddress, tokenAddress gethcommon.Address) (*big.Int, error) } func ClaimCmd(p utils.Prompter) *cli.Command { @@ -123,15 +124,22 @@ func Claim(cCtx *cli.Context, p utils.Prompter) error { return eigenSdkUtils.WrapError("failed to fetch claim amounts for date", err) } - claimableTokens, present := proofData.Distribution.GetTokensForEarner(config.EarnerAddress) + claimableTokensOrderMap, present := proofData.Distribution.GetTokensForEarner(config.EarnerAddress) if !present { return errors.New("no tokens claimable by earner") } + claimableTokensMap := getTokensToClaim(claimableTokensOrderMap, config.TokenAddresses) + + claimableTokens, err := filterClaimableTokens(ctx, elReader, config.EarnerAddress, claimableTokensMap) + if err != nil { + return eigenSdkUtils.WrapError("failed to get claimable tokens", err) + } + cg := claimgen.NewClaimgen(proofData.Distribution) accounts, claim, err := cg.GenerateClaimProofForEarner( config.EarnerAddress, - getTokensToClaim(claimableTokens, config.TokenAddresses), + claimableTokens, rootIndex, ) if err != nil { @@ -270,6 +278,30 @@ func Claim(cCtx *cli.Context, p utils.Prompter) error { return nil } +// filterClaimableTokens to filter out tokens that have been fully claimed +func filterClaimableTokens( + ctx context.Context, + elReader elChainReader, + earnerAddress gethcommon.Address, + claimableTokensMap map[gethcommon.Address]*big.Int, +) ([]gethcommon.Address, error) { + claimableTokens := make([]gethcommon.Address, 0) + for token, claimedAmount := range claimableTokensMap { + amount, err := getCummulativeClaimedRewards(ctx, elReader, earnerAddress, token) + if err != nil { + return nil, err + } + // If the token has been claimed fully, we don't need to include it in the claim + // This is because contracts reject claims for tokens that have been fully claimed + // https://github.com/Layr-Labs/eigenlayer-contracts/blob/ac57bc1b28c83d9d7143c0da19167c148c3596a3/src/contracts/core/RewardsCoordinator.sol#L575-L578 + if claimedAmount.Cmp(amount) == 0 { + continue + } + claimableTokens = append(claimableTokens, token) + } + return claimableTokens, nil +} + func getClaimDistributionRoot( ctx context.Context, claimTimestamp string, @@ -312,39 +344,40 @@ func getClaimDistributionRoot( func getTokensToClaim( claimableTokens *orderedmap.OrderedMap[gethcommon.Address, *distribution.BigInt], tokenAddresses []gethcommon.Address, -) []gethcommon.Address { +) map[gethcommon.Address]*big.Int { + var tokenMap map[gethcommon.Address]*big.Int if len(tokenAddresses) == 0 { - tokenAddresses = getAllClaimableTokenAddresses(claimableTokens) + tokenMap = getAllClaimableTokenAddresses(claimableTokens) } else { - tokenAddresses = filterClaimableTokenAddresses(claimableTokens, tokenAddresses) + tokenMap = filterClaimableTokenAddresses(claimableTokens, tokenAddresses) } - return tokenAddresses + return tokenMap } func getAllClaimableTokenAddresses( addressesMap *orderedmap.OrderedMap[gethcommon.Address, *distribution.BigInt], -) []gethcommon.Address { - var addresses []gethcommon.Address +) map[gethcommon.Address]*big.Int { + tokens := make(map[gethcommon.Address]*big.Int) for pair := addressesMap.Oldest(); pair != nil; pair = pair.Next() { - addresses = append(addresses, pair.Key) + tokens[pair.Key] = pair.Value.Int } - return addresses + return tokens } func filterClaimableTokenAddresses( addressesMap *orderedmap.OrderedMap[gethcommon.Address, *distribution.BigInt], providedAddresses []gethcommon.Address, -) []gethcommon.Address { - var addresses []gethcommon.Address +) map[gethcommon.Address]*big.Int { + tokens := make(map[gethcommon.Address]*big.Int) for _, address := range providedAddresses { - if _, ok := addressesMap.Get(address); ok { - addresses = append(addresses, address) + if val, ok := addressesMap.Get(address); ok { + tokens[address] = val.Int } } - return addresses + return tokens } func convertClaimTokenLeaves( diff --git a/pkg/rewards/claim_test.go b/pkg/rewards/claim_test.go index 77d83ada..6d9f9e81 100644 --- a/pkg/rewards/claim_test.go +++ b/pkg/rewards/claim_test.go @@ -27,10 +27,14 @@ import ( ) type fakeELReader struct { - roots []rewardscoordinator.IRewardsCoordinatorDistributionRoot + roots []rewardscoordinator.IRewardsCoordinatorDistributionRoot + earnerTokenClaimedMap map[common.Address]map[common.Address]*big.Int } -func newFakeELReader(now time.Time) *fakeELReader { +func newFakeELReader( + now time.Time, + earnerTokenClaimedMap map[common.Address]map[common.Address]*big.Int, +) *fakeELReader { roots := make([]rewardscoordinator.IRewardsCoordinatorDistributionRoot, 0) rootOne := rewardscoordinator.IRewardsCoordinatorDistributionRoot{ Root: [32]byte{0x01}, @@ -60,7 +64,8 @@ func newFakeELReader(now time.Time) *fakeELReader { return roots[i].ActivatedAt < roots[j].ActivatedAt }) return &fakeELReader{ - roots: roots, + roots: roots, + earnerTokenClaimedMap: earnerTokenClaimedMap, } } @@ -91,6 +96,21 @@ func (f *fakeELReader) GetCurrentClaimableDistributionRoot( return rewardscoordinator.IRewardsCoordinatorDistributionRoot{}, errors.New("no active distribution root found") } +func (f *fakeELReader) GetCumulativeClaimed( + ctx context.Context, + earnerAddress, + tokenAddress common.Address, +) (*big.Int, error) { + if f.earnerTokenClaimedMap == nil { + return big.NewInt(0), nil + } + claimed, ok := f.earnerTokenClaimedMap[earnerAddress][tokenAddress] + if !ok { + return big.NewInt(0), nil + } + return claimed, nil +} + func (f *fakeELReader) CurrRewardsCalculationEndTimestamp(ctx context.Context) (uint32, error) { rootLen, err := f.GetDistributionRootsLength(ctx) if err != nil { @@ -246,7 +266,7 @@ func TestGetClaimDistributionRoot(t *testing.T) { }, } - reader := newFakeELReader(now) + reader := newFakeELReader(now, nil) logger := logging.NewJsonSLogger(os.Stdout, &logging.SLoggerOptions{}) for _, tt := range tests { @@ -280,13 +300,18 @@ func TestGetTokensToClaim(t *testing.T) { // Case 1: No token addresses provided, should return all addresses in claimableTokens result := getTokensToClaim(claimableTokens, []common.Address{}) - expected := []common.Address{addr1, addr2} - assert.ElementsMatch(t, result, expected) + expected := map[common.Address]*big.Int{ + addr1: big.NewInt(100), + addr2: big.NewInt(200), + } + assert.Equal(t, result, expected) // Case 2: Provided token addresses, should return only those present in claimableTokens result = getTokensToClaim(claimableTokens, []common.Address{addr2, addr3}) - expected = []common.Address{addr2} - assert.ElementsMatch(t, result, expected) + expected = map[common.Address]*big.Int{ + addr2: big.NewInt(200), + } + assert.Equal(t, result, expected) } func TestGetTokenAddresses(t *testing.T) { @@ -300,8 +325,11 @@ func TestGetTokenAddresses(t *testing.T) { // Test that the function returns all addresses in the map result := getAllClaimableTokenAddresses(addressesMap) - expected := []common.Address{addr1, addr2} - assert.ElementsMatch(t, result, expected) + expected := map[common.Address]*big.Int{ + addr1: big.NewInt(100), + addr2: big.NewInt(200), + } + assert.Equal(t, result, expected) } func TestFilterClaimableTokenAddresses(t *testing.T) { @@ -321,8 +349,70 @@ func TestFilterClaimableTokenAddresses(t *testing.T) { } result := filterClaimableTokenAddresses(addressesMap, providedAddresses) - expected := []common.Address{addr1} - assert.ElementsMatch(t, result, expected) + expected := map[common.Address]*big.Int{ + addr1: big.NewInt(100), + } + assert.Equal(t, result, expected) +} + +func TestFilterClaimableTokens(t *testing.T) { + // Set up a mock claimableTokens map + earnerAddress := common.HexToAddress(testutils.GenerateRandomEthereumAddressString()) + tokenAddress1 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString()) + tokenAddress2 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString()) + amountClaimed1 := big.NewInt(100) + amountClaimed2 := big.NewInt(200) + elReaderClaimedMap := map[common.Address]map[common.Address]*big.Int{ + earnerAddress: { + tokenAddress1: amountClaimed1, + tokenAddress2: amountClaimed2, + }, + } + now := time.Now() + reader := newFakeELReader(now, elReaderClaimedMap) + tests := []struct { + name string + earnerAddress common.Address + claimableTokensMap map[common.Address]*big.Int + expectedClaimableTokens []common.Address + }{ + { + name: "all tokens are claimable and non zero", + earnerAddress: earnerAddress, + claimableTokensMap: map[common.Address]*big.Int{ + tokenAddress1: big.NewInt(2345), + tokenAddress2: big.NewInt(3345), + }, + expectedClaimableTokens: []common.Address{ + tokenAddress1, + tokenAddress2, + }, + }, + { + name: "one token is already claimed", + earnerAddress: earnerAddress, + claimableTokensMap: map[common.Address]*big.Int{ + tokenAddress1: amountClaimed1, + tokenAddress2: big.NewInt(1234), + }, + expectedClaimableTokens: []common.Address{ + tokenAddress2, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := filterClaimableTokens( + context.Background(), + reader, + tt.earnerAddress, + tt.claimableTokensMap, + ) + assert.NoError(t, err) + assert.ElementsMatch(t, tt.expectedClaimableTokens, result) + }) + } } func newBigInt(value int64) *distribution.BigInt { diff --git a/pkg/rewards/show.go b/pkg/rewards/show.go index e86b934b..0d646761 100644 --- a/pkg/rewards/show.go +++ b/pkg/rewards/show.go @@ -169,18 +169,31 @@ func getClaimedRewards( ) (map[gethcommon.Address]*big.Int, error) { claimedRewards := make(map[gethcommon.Address]*big.Int) for address := range allRewards { - claimed, err := elReader.GetCumulativeClaimed(ctx, earnerAddress, address) + claimed, err := getCummulativeClaimedRewards(ctx, elReader, earnerAddress, address) if err != nil { return nil, err } - if claimed == nil { - claimed = big.NewInt(0) - } claimedRewards[address] = claimed } return claimedRewards, nil } +func getCummulativeClaimedRewards( + ctx context.Context, + elReader ELReader, + earnerAddress gethcommon.Address, + tokenAddress gethcommon.Address, +) (*big.Int, error) { + claimed, err := elReader.GetCumulativeClaimed(ctx, earnerAddress, tokenAddress) + if err != nil { + return nil, err + } + if claimed == nil { + claimed = big.NewInt(0) + } + return claimed, nil +} + func calculateUnclaimedRewards( allRewards, claimedRewards map[gethcommon.Address]*big.Int,