diff --git a/src/ensembles.jl b/src/ensembles.jl index 122a87b..28cb2e0 100644 --- a/src/ensembles.jl +++ b/src/ensembles.jl @@ -200,7 +200,6 @@ _reducer(p, q) = vcat(p, q) _reducer(p::Tuple, q::Tuple) = (vcat(p[1], q[1]), vcat(p[2], q[2])) - # # ENSEMBLE MODEL TYPES mutable struct DeterministicEnsembleModel{Atom<:Deterministic} <: Deterministic @@ -638,11 +637,8 @@ end # Note: input and target traits are inherited from atom -MMI.load_path(::Type{<:ProbabilisticEnsembleModel}) = - "MLJ.ProbabilisticEnsembleModel" -MMI.load_path(::Type{<:DeterministicEnsembleModel}) = - "MLJ.DeterministicEnsembleModel" - +MMI.load_path(::Type{<:EitherEnsembleModel}) = "MLJEnsembles.EnsembleModel" +MMI.constructor(::Type{<:EitherEnsembleModel}) = EnsembleModel MMI.is_wrapper(::Type{<:EitherEnsembleModel}) = true MMI.supports_weights(::Type{<:EitherEnsembleModel{Atom}}) where Atom = MMI.supports_weights(Atom) diff --git a/test/ensembles.jl b/test/ensembles.jl index f371dde..1cb1dfc 100644 --- a/test/ensembles.jl +++ b/test/ensembles.jl @@ -63,6 +63,9 @@ X = MLJEnsembles.table(ones(5,3)) y = categorical(collect("asdfa")) train, test = partition(1:length(y), 0.8); ensemble_model = EnsembleModel(model=atom) +@test constructor(ensemble_model) == EnsembleModel +@test load_path(ensemble_model) == "MLJEnsembles.EnsembleModel" +@test package_name(ensemble_model) == "MLJEnsembles" ensemble_model.n = 10 fitresult, cache, report = MLJEnsembles.fit(ensemble_model, 0, X, y) predict(ensemble_model, fitresult, MLJEnsembles.selectrows(X, test))