Skip to content

Commit

Permalink
Allow feature names to be supplied for outputs (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
trevorstephens authored Dec 18, 2018
1 parent 3e21f52 commit 07b41a1
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 7 deletions.
22 changes: 18 additions & 4 deletions gplearn/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ class _Program(object):
The reason for this being passed is that during parallel evolution the
same program object may be accessed by multiple parallel processes.
feature_names : list, optional (default=None)
Optional list of feature names, used purely for representations in
the `print` operation or `export_graphviz`. If None, then X0, X1, etc
will be used for representations.
program : list, optional (default=None)
The flattened tree representation of the program. If None, a new naive
random tree will be grown. If provided, it will be validated.
Expand Down Expand Up @@ -122,6 +127,7 @@ def __init__(self,
p_point_replace,
parsimony_coefficient,
random_state,
feature_names=None,
program=None):

self.function_set = function_set
Expand All @@ -133,11 +139,12 @@ def __init__(self,
self.metric = metric
self.p_point_replace = p_point_replace
self.parsimony_coefficient = parsimony_coefficient
self.feature_names = feature_names
self.program = program

if self.program is not None:
if not self.validate_program():
raise ValueError('The supplied program is incomplete')
raise ValueError('The supplied program is incomplete.')
else:
# Create a naive random program
self.program = self.build_program(random_state)
Expand Down Expand Up @@ -232,7 +239,10 @@ def __str__(self):
output += node.name + '('
else:
if isinstance(node, int):
output += 'X%s' % node
if self.feature_names is None:
output += 'X%s' % node
else:
output += self.feature_names[node]
else:
output += '%.3f' % node
terminals[-1] -= 1
Expand Down Expand Up @@ -275,8 +285,12 @@ def export_graphviz(self, fade_nodes=None):
if i not in fade_nodes:
fill = '#60a6f6'
if isinstance(node, int):
output += ('%d [label="%s%s", fillcolor="%s"] ;\n'
% (i, 'X', node, fill))
if self.feature_names is None:
feature_name = 'X%s' % node
else:
feature_name = self.feature_names[node]
output += ('%d [label="%s", fillcolor="%s"] ;\n'
% (i, feature_name, fill))
else:
output += ('%d [label="%.3f", fillcolor="%s"] ;\n'
% (i, node, fill))
Expand Down
28 changes: 28 additions & 0 deletions gplearn/genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _parallel_evolve(n_programs, parents, X, y, sample_weight, seeds, params):
method_probs = params['method_probs']
p_point_replace = params['p_point_replace']
max_samples = params['max_samples']
feature_names = params['feature_names']

max_samples = int(max_samples * n_samples)

Expand Down Expand Up @@ -118,6 +119,7 @@ def _tournament():
const_range=const_range,
p_point_replace=p_point_replace,
parsimony_coefficient=parsimony_coefficient,
feature_names=feature_names,
random_state=random_state,
program=program)

Expand Down Expand Up @@ -176,6 +178,7 @@ def __init__(self,
p_point_mutation=0.01,
p_point_replace=0.05,
max_samples=1.0,
feature_names=None,
warm_start=False,
low_memory=False,
n_jobs=1,
Expand All @@ -200,6 +203,7 @@ def __init__(self,
self.p_point_mutation = p_point_mutation
self.p_point_replace = p_point_replace
self.max_samples = max_samples
self.feature_names = feature_names
self.warm_start = warm_start
self.low_memory = low_memory
self.n_jobs = n_jobs
Expand Down Expand Up @@ -355,6 +359,16 @@ def fit(self, X, y, sample_weight=None):
raise ValueError('init_depth should be in increasing numerical '
'order: (min_depth, max_depth).')

if self.feature_names is not None:
if self.n_features_ != len(self.feature_names):
raise ValueError('The supplied `feature_names` has different '
'length to n_features. Expected %d, got %d.'
% (self.n_features_, len(self.feature_names)))
for feature_name in self.feature_names:
if not isinstance(feature_name, six.string_types):
raise ValueError('invalid type %s found in '
'`feature_names`.' % type(feature_name))

params = self.get_params()
params['_metric'] = self._metric
params['function_set'] = self._function_set
Expand Down Expand Up @@ -667,6 +681,11 @@ class SymbolicRegressor(BaseSymbolic, RegressorMixin):
max_samples : float, optional (default=1.0)
The fraction of samples to draw from X to evaluate each program on.
feature_names : list, optional (default=None)
Optional list of feature names, used purely for representations in
the `print` operation or `export_graphviz`. If None, then X0, X1, etc
will be used for representations.
warm_start : bool, optional (default=False)
When set to ``True``, reuse the solution of the previous call to fit
and add more generations to the evolution, otherwise, just fit a new
Expand Down Expand Up @@ -733,6 +752,7 @@ def __init__(self,
p_point_mutation=0.01,
p_point_replace=0.05,
max_samples=1.0,
feature_names=None,
warm_start=False,
low_memory=False,
n_jobs=1,
Expand All @@ -755,6 +775,7 @@ def __init__(self,
p_point_mutation=p_point_mutation,
p_point_replace=p_point_replace,
max_samples=max_samples,
feature_names=feature_names,
warm_start=warm_start,
low_memory=low_memory,
n_jobs=n_jobs,
Expand Down Expand Up @@ -944,6 +965,11 @@ class SymbolicTransformer(BaseSymbolic, TransformerMixin):
max_samples : float, optional (default=1.0)
The fraction of samples to draw from X to evaluate each program on.
feature_names : list, optional (default=None)
Optional list of feature names, used purely for representations in
the `print` operation or `export_graphviz`. If None, then X0, X1, etc
will be used for representations.
warm_start : bool, optional (default=False)
When set to ``True``, reuse the solution of the previous call to fit
and add more generations to the evolution, otherwise, just fit a new
Expand Down Expand Up @@ -1012,6 +1038,7 @@ def __init__(self,
p_point_mutation=0.01,
p_point_replace=0.05,
max_samples=1.0,
feature_names=None,
warm_start=False,
low_memory=False,
n_jobs=1,
Expand All @@ -1036,6 +1063,7 @@ def __init__(self,
p_point_mutation=p_point_mutation,
p_point_replace=p_point_replace,
max_samples=max_samples,
feature_names=feature_names,
warm_start=warm_start,
low_memory=low_memory,
n_jobs=n_jobs,
Expand Down
46 changes: 43 additions & 3 deletions gplearn/tests/test_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,17 @@ def test_validate_program():
# This one should be fine
_ = _Program(function_set, arities, init_depth, init_method, n_features,
const_range, metric, p_point_replace, parsimony_coefficient,
random_state, test_gp)
random_state, None, test_gp)

# Now try a couple that shouldn't be
assert_raises(ValueError, _Program, function_set, arities, init_depth,
init_method, n_features, const_range, metric,
p_point_replace, parsimony_coefficient, random_state,
test_gp[:-1])
None, test_gp[:-1])
assert_raises(ValueError, _Program, function_set, arities, init_depth,
init_method, n_features, const_range, metric,
p_point_replace, parsimony_coefficient, random_state,
test_gp + [1])
None, test_gp + [1])


def test_print_overloading():
Expand Down Expand Up @@ -222,6 +222,22 @@ def test_print_overloading():
lisp = "mul(div(X8, X1), sub(X9, 0.500))"
assert_true(output == lisp)

# Test with feature names
params['feature_names'] = [str(n) for n in range(10)]
gp = _Program(random_state=random_state, program=test_gp, **params)

orig_stdout = sys.stdout
try:
out = StringIO()
sys.stdout = out
print(gp)
output = out.getvalue().strip()
finally:
sys.stdout = orig_stdout

lisp = "mul(div(8, 1), sub(9, 0.500))"
assert_true(output == lisp)


def test_export_graphviz():
"""Check output of a simple program to Graphviz"""
Expand Down Expand Up @@ -253,7 +269,16 @@ def test_export_graphviz():
'4 -> 6 ;\n4 -> 5 ;\n0 -> 4 ;\n0 -> 1 ;\n}'
assert_true(output == tree)

# Test with feature names
params['feature_names'] = [str(n) for n in range(10)]
gp = _Program(random_state=random_state, program=test_gp, **params)
output = gp.export_graphviz()
tree = tree.replace('X', '')
assert_true(output == tree)

# Test with fade_nodes
params['feature_names'] = None
gp = _Program(random_state=random_state, program=test_gp, **params)
output = gp.export_graphviz(fade_nodes=[0, 1, 2, 3])
tree = 'digraph program {\n' \
'node [style=filled]0 [label="mul", fillcolor="#cecece"] ;\n' \
Expand All @@ -276,6 +301,21 @@ def test_export_graphviz():
assert_true(output == tree)


def test_invalid_feature_names():
    """Check that malformed `feature_names` make `fit` raise ValueError.

    Each symbolic estimator is exercised with two bad inputs in turn: a
    list intended to have the wrong length, and a list that contains an
    entry intended to have an invalid type (an int rather than a string).
    """
    for Symbolic in (SymbolicRegressor, SymbolicTransformer):
        # Fresh lists are built for every estimator class so that no list
        # object is shared between estimator instances.
        bad_inputs = (
            # Invalid length: only two names are supplied.
            ['foo', 'bar'],
            # Invalid type: twelve string names followed by the int 0.
            # NOTE(review): presumably 13 entries matches the number of
            # columns in boston.data, so only the type check trips; confirm.
            [str(n) for n in range(12)] + [0],
        )
        for feature_names in bad_inputs:
            est = Symbolic(feature_names=feature_names)
            assert_raises(ValueError, est.fit, boston.data, boston.target)


def test_execute():
"""Check executing the program works"""

Expand Down

0 comments on commit 07b41a1

Please sign in to comment.