Skip to content

Commit

Permalink
gshift as gforce optimized shift (#5205)
Browse files Browse the repository at this point in the history
  • Loading branch information
ben-schwen authored Oct 20, 2021
1 parent fa76197 commit e88826e
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 7 deletions.
34 changes: 34 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,40 @@

29. `setkey()` now supports type `raw` as value columns (not as key columns), [#5100](https://github.com/Rdatatable/data.table/issues/5100). Thanks Hugh Parsonage for requesting, and Benjamin Schwendinger for the PR.

30. `shift()` is now optimised by group, [#1534](https://github.com/Rdatatable/data.table/issues/1534). Thanks to Gerhard Nachtmann for requesting, and Benjamin Schwendinger for the PR.

```R
N = 1e7
DT = data.table(x=sample(N), y=sample(1e6,N,TRUE))
shift_no_opt = shift # an alias under a different name is not GForce-optimised, so it serves as the baseline for comparison
microbenchmark(
DT[, c(NA, head(x,-1)), y],
DT[, shift_no_opt(x, 1, type="lag"), y],
DT[, shift(x, 1, type="lag"), y],
times=10L, unit="s")
# Unit: seconds
# expr min lq mean median uq max neval
# DT[, c(NA, head(x, -1)), y] 8.7620 9.0240 9.1870 9.2800 9.3700 9.4110 10
# DT[, shift_no_opt(x, 1, type = "lag"), y] 20.5500 20.9000 21.1600 21.3200 21.4400 21.5200 10
# DT[, shift(x, 1, type = "lag"), y] 0.4865 0.5238 0.5463 0.5446 0.5725 0.5982 10
```

Example from [Stack Overflow](https://stackoverflow.com/questions/35179911/shift-in-data-table-v1-9-6-is-slow-for-many-groups):
```R
set.seed(1)
mg = data.table(expand.grid(year=2012:2016, id=1:1000),
value=rnorm(5000))
microbenchmark(v1.9.4 = mg[, c(value[-1], NA), by=id],
v1.9.6 = mg[, shift_no_opt(value, n=1, type="lead"), by=id],
v1.14.4 = mg[, shift(value, n=1, type="lead"), by=id],
unit="ms")
# Unit: milliseconds
# expr min lq mean median uq max neval
# v1.9.4 3.6600 3.8250 4.4930 4.1720 4.9490 11.700 100
# v1.9.6 18.5400 19.1800 21.5100 20.6900 23.4200 29.040 100
# v1.14.4 0.4826 0.5586 0.6586 0.6329 0.7348 1.318 100
```

## BUG FIXES

1. `by=.EACHI` when `i` is keyed but `on=` different columns than `i`'s key could create an invalidly keyed result, [#4603](https://github.com/Rdatatable/data.table/issues/4603) [#4911](https://github.com/Rdatatable/data.table/issues/4911). Thanks to @myoung3 and @adamaltmejd for reporting, and @ColeMiller1 for the PR. An invalid key is where a `data.table` is marked as sorted by the key columns but the data is not sorted by those columns, leading to incorrect results from subsequent queries.
Expand Down
24 changes: 23 additions & 1 deletion R/data.table.R
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,10 @@ replace_dot_alias = function(e) {
if (!(is.call(q) && is.symbol(q[[1L]]) && is.symbol(q[[2L]]) && (q1 <- q[[1L]]) %chin% gfuns)) return(FALSE)
if (!(q2 <- q[[2L]]) %chin% names(SDenv$.SDall) && q2 != ".I") return(FALSE) # 875
if ((length(q)==2L || (!is.null(names(q)) && startsWith(names(q)[3L], "na")))) return(TRUE)
if (length(q)>=2L && q[[1L]] == "shift") {
q_named = match.call(shift, q)
if (!is.call(q_named[["fill"]]) && is.null(q_named[["give.names"]])) return(TRUE)
} # add gshift support
# ^^ base::startWith errors on NULL unfortunately
# head-tail uses default value n=6 which as of now should not go gforce ... ^^
# otherwise there must be three arguments, and only in two cases:
Expand Down Expand Up @@ -1848,6 +1852,17 @@ replace_dot_alias = function(e) {
gi = if (length(o__)) o__[f__] else f__
g = lapply(grpcols, function(i) groups[[i]][gi])

# GForce functions whose result has one value per input row rather than one per group
nrow_funs = "gshift"
# Does the (already GForce-rewritten) j expression `q` call any function in nrow_funs?
# A list(...) call is searched recursively, element by element; any other call is
# judged by its head alone; anything that is not a call yields FALSE.
.is_nrows = function(q) {
  if (!is.call(q)) return(FALSE)
  fn = q[[1L]]
  if (fn != "list") return(fn %chin% nrow_funs)
  # the head symbol `list` is visited too, but it is not a call so it contributes FALSE
  any(vapply(q, .is_nrows, logical(1L)))
}

# adding ghead/gtail(n) support for n > 1 #5060 #523
q3 = 0
if (!is.symbol(jsub)) {
Expand All @@ -1865,6 +1880,8 @@ replace_dot_alias = function(e) {
if (q3 > 0) {
grplens = pmin.int(q3, len__)
g = lapply(g, rep.int, times=grplens)
} else if (.is_nrows(jsub)) {
g = lapply(g, rep.int, times=len__)
}
ans = c(g, ans)
} else {
Expand Down Expand Up @@ -2970,7 +2987,7 @@ rleidv = function(x, cols=seq_along(x), prefix=NULL) {
# (2) edit .gforce_ok (defined within `[`) to catch which j will apply the new function
# (3) define the gfun = function() R wrapper
gfuns = c("[", "[[", "head", "tail", "first", "last", "sum", "mean", "prod",
"median", "min", "max", "var", "sd", ".N") # added .N for #334
"median", "min", "max", "var", "sd", ".N", "shift") # added .N for #334
`g[` = `g[[` = function(x, n) .Call(Cgnthvalue, x, as.integer(n)) # n is of length=1 here.
ghead = function(x, n) .Call(Cghead, x, as.integer(n)) # n is not used at the moment
gtail = function(x, n) .Call(Cgtail, x, as.integer(n)) # n is not used at the moment
Expand All @@ -2984,6 +3001,11 @@ gmin = function(x, na.rm=FALSE) .Call(Cgmin, x, na.rm)
gmax = function(x, na.rm=FALSE) .Call(Cgmax, x, na.rm)
gvar = function(x, na.rm=FALSE) .Call(Cgvar, x, na.rm)
gsd = function(x, na.rm=FALSE) .Call(Cgsd, x, na.rm)
# GForce counterpart of shift(): validates the arguments at R level and hands the
# work to the C routine Cgshift.
#   x    column to shift
#   n    numeric shift amount(s); converted to integer before the C call
#   fill value used to pad; passed through to C untouched
#   type one of "lag", "lead", "shift", "cyclic"; defaults to the first choice
gshift = function(x, n=1L, fill=NA, type=c("lag", "lead", "shift", "cyclic")) {
  type = match.arg(type)  # reduce the vector of choices to one validated string
  stopifnot(is.numeric(n))
  n_int = as.integer(n)
  .Call(Cgshift, x, n_int, fill, type)
}
gforce = function(env, jsub, o, f, l, rows) .Call(Cgforce, env, jsub, o, f, l, rows)

.prepareFastSubset = function(isub, x, enclos, notjoin, verbose = FALSE){
Expand Down
Binary file added inst/tests/test2224.Rdata
Binary file not shown.
65 changes: 59 additions & 6 deletions inst/tests/tests.Rraw
Original file line number Diff line number Diff line change
Expand Up @@ -18243,21 +18243,39 @@ test(2217, DT1[, by = grp, .(agg = list(setNames(as.numeric(value), id)))], DT2)
testnum = 2218
funs = c(as.integer, as.double, as.complex, as.character, if (test_bit64) as.integer64)
# when test_bit64==FALSE these all passed before; now passes with test_bit64==TRUE too
# add grouping tests for #5205
g = rep(c(1,2), each=2)
options(datatable.optimize = 2L)
for (f1 in funs) {
DT = data.table(x=f1(1:4))
DT = data.table(x=f1(1:4), g=g)
for (f2 in funs) {
testnum = testnum + 0.01
testnum = testnum + 0.001
test(testnum, DT[, shift(x)], f1(c(NA, 1:3)))
testnum = testnum + 0.01
testnum = testnum + 0.001
w = if (identical(f2,as.character) && !identical(f1,as.character)) "Coercing.*character.*to match the type of target vector"
test(testnum, DT[, shift(x, fill=f2(NA))], f1(c(NA, 1:3)), warning=w)
testnum = testnum + 0.01
testnum = testnum + 0.001
if (identical(f1,as.character) && identical(f2,as.complex)) {
# one special case due to as.complex(0)=="0+0i"!="0"
test(testnum, DT[, shift(x, fill="0")], f1(0:3))
} else {
test(testnum, DT[, shift(x, fill=f2(0))], f1(0:3), warning=w)
}

testnum = testnum + 0.001
test(testnum, DT[, shift(x), by=g], data.table(g=g, V1=f1(c(NA, 1, NA, 3))))
testnum = testnum + 0.001
w = if (identical(f2,as.character) && !identical(f1,as.character)) "Coercing.*character.*to match the type of target vector"
f = f2(NA)
test(testnum, DT[, shift(x, fill=f), by=g], data.table(g=g, V1=f1(c(NA, 1, NA, 3))), warning=w)
testnum = testnum + 0.001
if (identical(f1,as.character) && identical(f2,as.complex)) {
# one special case due to as.complex(0)=="0+0i"!="0"
test(testnum, DT[, shift(x, fill="0"), by=g], data.table(g=g, V1=f1(c(0,1,0,3))))
} else {
f = f2(0)
test(testnum, DT[, shift(x, fill=f), by=g], data.table(g=g, V1=f1(c(0,1,0,3))), warning=w)
}
}
}

Expand Down Expand Up @@ -18292,6 +18310,41 @@ DT = data.table(A=1:3, key="A")
test(2223.1, DT[.(4), nomatch=FALSE], data.table(A=integer(), key="A"))
test(2223.2, DT[.(4), nomatch=NA_character_], data.table(A=4L, key="A"))

# gshift, #5205
options(datatable.optimize = 2L)
set.seed(123)
DT = data.table(x = sample(letters[1:5], 20, TRUE),
y = rep.int(1:2, 10), # to test 2 grouping columns get rep'd properly
i = sample(c(-2L,0L,3L,NA), 20, TRUE),
d = sample(c(1.2,-3.4,5.6,NA), 20, TRUE),
s = sample(c("foo","bar",NA), 20, TRUE),
c = sample(c(0+3i,1,-1-1i,NA), 20, TRUE),
l = sample(c(TRUE, FALSE, NA), 20, TRUE),
r = as.raw(sample(1:5, 20, TRUE)))
load(testDir("test2224.Rdata")) # ans array
if (test_bit64) {
DT[, i64:=as.integer64(sample(c(-2L,0L,2L,NA), 20, TRUE))]
} else {
ans = ans[, -match("i64",colnames(ans))]
}
test(2224.01, sapply(names(DT)[-1], function(col) {
sapply(list(1, 5, -1, -5, c(1,2), c(-1,1)), function(n) list(
# fill is tested by group in tests 2218.*; see comments in #5205
EVAL(sprintf("DT[, shift(%s, %d, type='lag'), by=x]$V1", col, n)),
EVAL(sprintf("DT[, shift(%s, %d, type='lead'), by=x]$V1", col, n)),
EVAL(sprintf("DT[, shift(%s, %d, type='shift'), by=x]$V1", col, n)),
EVAL(sprintf("DT[, shift(%s, %d, type='cyclic'), by=x]$V1", col, n))
))
}), ans)
a = 1:2 # fill argument with length > 1 which is not a call
test(2224.02, DT[, shift(i, fill=a), by=x], error="fill must be a vector of length 1")
DT = data.table(x=pairlist(1), g=1)
# unsupported type as argument
test(2224.03, DT[, shift(x), g], error="Type 'list' is not supported by GForce gshift.")

# groupingsets by named by argument
test(2224.1, groupingsets(data.table(iris), j = sum(Sepal.Length), by = c('Sp'='Species'), sets = list('Species')), data.table(Species = factor(c("setosa", "versicolor", "virginica")), V1=c(250.3, 296.8, 329.4)))
test(2224.2, groupingsets(data.table(iris), j = mean(Sepal.Length), by = c('Sp'='Species'), sets = list('Species')), groupingsets(data.table(iris), j = mean(Sepal.Length), by = c('Species'), sets = list('Species')))
test(2225.1, groupingsets(data.table(iris), j=sum(Sepal.Length), by=c('Sp'='Species'), sets=list('Species')),
data.table(Species=factor(c("setosa", "versicolor", "virginica")), V1=c(250.3, 296.8, 329.4)))
test(2225.2, groupingsets(data.table(iris), j=mean(Sepal.Length), by=c('Sp'='Species'), sets=list('Species')),
groupingsets(data.table(iris), j=mean(Sepal.Length), by=c('Species'), sets=list('Species')))

90 changes: 90 additions & 0 deletions src/gsumm.c
Original file line number Diff line number Diff line change
Expand Up @@ -1162,3 +1162,93 @@ SEXP gprod(SEXP x, SEXP narmArg) {
return(ans);
}

/*
 * GForce implementation of shift(): shifts x within every group in one pass.
 *
 *   x        column to shift; raw, logical, integer, double (incl. integer64),
 *            complex and character are handled, anything else raises an error
 *   nArg     integer vector of shift amounts; a negative amount flips the direction
 *   fillArg  length-1 padding value, coerced to the type of x
 *   typeArg  length-1 string: "lag", "lead", "shift" (treated as lag) or "cyclic"
 *
 * Returns a list with one vector per element of nArg. Values are written group after
 * group, in the order the groups are visited (ff/grpsize, via oo when unsorted).
 *
 * Relies on the file-level GForce state (irowslen, irows, isunsorted, oo, ff, grpsize,
 * ngrp, nrow) which is populated outside this function -- NOTE(review): presumably by
 * gforce(); confirm.
 */
SEXP gshift(SEXP x, SEXP nArg, SEXP fillArg, SEXP typeArg) {
  const bool nosubset = irowslen == -1;
  const bool issorted = !isunsorted;
  const int n = nosubset ? length(x) : irowslen;
  if (nrow != n) error(_("Internal error: nrow [%d] != length(x) [%d] in %s"), nrow, n, "gshift");

  int nprotect=0;
  enum {LAG, LEAD/*, SHIFT*/,CYCLIC} stype = LAG;
  if (!(length(fillArg) == 1))
    error(_("fill must be a vector of length 1"));

  if (!isString(typeArg) || length(typeArg) != 1)
    error(_("Internal error: invalid type for gshift(), should have been caught before. please report to data.table issue tracker")); // # nocov
  const char *type = CHAR(STRING_ELT(typeArg, 0));
  if (!strcmp(type, "lag")) stype = LAG;
  else if (!strcmp(type, "lead")) stype = LEAD;
  else if (!strcmp(type, "shift")) stype = LAG; // "shift" behaves exactly like "lag"
  else if (!strcmp(type, "cyclic")) stype = CYCLIC;
  else error(_("Internal error: invalid type for gshift(), should have been caught before. please report to data.table issue tracker")); // # nocov

  const bool cycle = stype == CYCLIC;

  const R_xlen_t nk = length(nArg);
  if (!isInteger(nArg)) error(_("Internal error: n must be integer")); // # nocov
  const int *kd = INTEGER(nArg);
  for (int i=0; i<nk; i++) if (kd[i]==NA_INTEGER) error(_("Item %d of n is NA"), i+1);

  SEXP ans = PROTECT(allocVector(VECSXP, nk)); nprotect++;
  SEXP thisfill = PROTECT(coerceAs(fillArg, x, ScalarLogical(0))); nprotect++;
  for (int g=0; g<nk; g++) {
    bool lag = stype == LAG || stype == CYCLIC;
    int m = kd[g];
    // a negative amount means shifting in the opposite direction by |m|
    if (m < 0) {
      m = m * (-1);
      lag = !lag;
    }
    R_xlen_t ansi = 0;
    SEXP tmp;
    // One output value is written per grouped row, so size the result by n (tied to nrow
    // by the check at the top) rather than by length(x): with an i subset the two differ,
    // and sizing by length(x) either leaves an uninitialised tail or, when rows are
    // repeated in the subset, lets the writes below run past the end of the vector.
    // NOTE(review): assumes the group sizes in grpsize[] sum to nrow -- confirm in gforce().
    SET_VECTOR_ELT(ans, g, tmp=allocVector(TYPEOF(x), n));
    // Per group the output is assembled from up to three consecutive pieces:
    //   lag : [thisn padded/wrapped values] then [the first grpn-mg values of the group]
    //   lead: [the values after the first thisn] then [thisn padded/wrapped values]
    // Padding is `fill`, except for cyclic shifts where the displaced values wrap around.
    #define SHIFT(CTYPE, RTYPE, ASSIGN) {                                                                          \
      const CTYPE *xd = (const CTYPE *)RTYPE(x);                                                                    \
      const CTYPE fill = RTYPE(thisfill)[0];                                                                        \
      for (int i=0; i<ngrp; ++i) {                                                                                  \
        const int grpn = grpsize[i];                                                                                \
        const int mg = cycle ? (((m-1) % grpn) + 1) : m;                                                            \
        const int thisn = MIN(mg, grpn);                                                                            \
        const int jstart = ff[i]-1+ (!lag)*(thisn);                                                                 \
        const int jend = jstart+ MAX(0, grpn-mg); /*if m > grpn -> jend = jstart */                                 \
        if (lag) {                                                                                                  \
          const int o = ff[i]-1+(grpn-thisn);                                                                       \
          for (int j=0; j<thisn; ++j) {                                                                             \
            const int k = issorted ? (o+j) : oo[o+j]-1;                                                             \
            const CTYPE val = cycle ? (nosubset ? xd[k] : (irows[k]==NA_INTEGER ? fill : xd[irows[k]-1])) : fill;   \
            ASSIGN;                                                                                                 \
          }                                                                                                         \
        }                                                                                                           \
        for (int j=jstart; j<jend; ++j) {                                                                           \
          const int k = issorted ? j : oo[j]-1;                                                                     \
          const CTYPE val = nosubset ? xd[k] : (irows[k]==NA_INTEGER ? fill : xd[irows[k]-1]);                      \
          ASSIGN;                                                                                                   \
        }                                                                                                           \
        if (!lag) {                                                                                                 \
          const int o = ff[i]-1;                                                                                    \
          for (int j=0; j<thisn; ++j) {                                                                             \
            const int k = issorted ? (o+j) : oo[o+j]-1;                                                             \
            const CTYPE val = cycle ? (nosubset ? xd[k] : (irows[k]==NA_INTEGER ? fill : xd[irows[k]-1])) : fill;   \
            ASSIGN;                                                                                                 \
          }                                                                                                         \
        }                                                                                                           \
      }                                                                                                             \
    }
    switch(TYPEOF(x)) {
    case RAWSXP:  { Rbyte *ansd=RAW(tmp);        SHIFT(Rbyte, RAW, ansd[ansi++]=val); } break;
    case LGLSXP:  { int *ansd=LOGICAL(tmp);      SHIFT(int, LOGICAL, ansd[ansi++]=val); } break;
    case INTSXP:  { int *ansd=INTEGER(tmp);      SHIFT(int, INTEGER, ansd[ansi++]=val); } break;
    case REALSXP: { double *ansd=REAL(tmp);      SHIFT(double, REAL, ansd[ansi++]=val); } break;
    // integer64 is shifted as if it's REAL; and assigning fill=NA_INTEGER64 is ok as REAL
    case CPLXSXP: { Rcomplex *ansd=COMPLEX(tmp); SHIFT(Rcomplex, COMPLEX, ansd[ansi++]=val); } break;
    case STRSXP:  { SHIFT(SEXP, STRING_PTR, SET_STRING_ELT(tmp,ansi++,val)); } break;
    //case VECSXP: { SHIFT(SEXP, SEXPPTR_RO, SET_VECTOR_ELT(tmp,ansi++,val)); } break;
    default:
      error(_("Type '%s' is not supported by GForce gshift. Either add the namespace prefix (e.g. data.table::shift(.)) or turn off GForce optimization using options(datatable.optimize=1)"), type2char(TYPEOF(x)));
    }
    copyMostAttrib(x, tmp); // needed for integer64 because without not the correct class of int64 is assigned
  }
  UNPROTECT(nprotect);
  return(ans);
}

2 changes: 2 additions & 0 deletions src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ SEXP dim();
SEXP gvar();
SEXP gsd();
SEXP gprod();
SEXP gshift();
SEXP nestedid();
SEXP setDTthreads();
SEXP getDTthreads_R();
Expand Down Expand Up @@ -197,6 +198,7 @@ R_CallMethodDef callMethods[] = {
{"Cgvar", (DL_FUNC) &gvar, -1},
{"Cgsd", (DL_FUNC) &gsd, -1},
{"Cgprod", (DL_FUNC) &gprod, -1},
{"Cgshift", (DL_FUNC) &gshift, -1},
{"Cnestedid", (DL_FUNC) &nestedid, -1},
{"CsetDTthreads", (DL_FUNC) &setDTthreads, -1},
{"CgetDTthreads", (DL_FUNC) &getDTthreads_R, -1},
Expand Down

0 comments on commit e88826e

Please sign in to comment.