Online optimization for Agents (#172)
* fix

* fix#2

* minor

* online optimization

* reverse to offline learning

* add check dim

* docstring

* minor

* change learning policy

* minors

* minor

* restore

* delete useless imports
maypink authored Sep 4, 2023
1 parent b84404f commit d5d1dfa
Showing 5 changed files with 19 additions and 13 deletions.
3 changes: 1 addition & 2 deletions examples/adaptive_optimizer/experiment_setup.py
@@ -2,7 +2,6 @@
 from typing import List, Sequence, Optional, Dict

 import networkx as nx
-import numpy as np
 from matplotlib import pyplot as plt
 from sklearn.cluster import KMeans

@@ -65,7 +64,7 @@ def run_adaptive_mutations_with_context(

 def log_action_values_with_clusters(next_pop: PopulationT, optimizer: EvoGraphOptimizer):
     obs_contexts = optimizer.mutation.agent.get_context(next_pop)
-    cluster.fit(np.array(obs_contexts).reshape(-1, 1))
+    cluster.fit(obs_contexts.reshape(-1, 1))
     centers = cluster.cluster_centers_
     for i, center in enumerate(centers):
         values = optimizer.mutation.agent.get_action_values(obs=center)
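
Note on this change: since get_context now returns a NumPy array (see the contextual_mab_agent.py diff below), the caller can reshape it directly instead of wrapping it in np.array first. A minimal, self-contained sketch of the shape contract, with made-up context values; scikit-learn estimators expect a 2-D (n_samples, n_features) input:

    import numpy as np
    from sklearn.cluster import KMeans

    obs_contexts = np.array([0.1, 0.4, 0.35, 0.9])  # hypothetical scalar contexts, shape (4,)
    cluster = KMeans(n_clusters=2, n_init=10, random_state=0)
    cluster.fit(obs_contexts.reshape(-1, 1))        # reshaped to (4, 1), as sklearn requires
    print(cluster.cluster_centers_.shape)           # (2, 1)
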
3 changes: 1 addition & 2 deletions experiments/mab/mab_synthetic_experiment_helper.py
@@ -4,7 +4,6 @@
 from functools import partial
 from pprint import pprint

-import numpy as np
 import pandas as pd
 import seaborn as sns

@@ -76,7 +75,7 @@ def log_action_values(next_pop: PopulationT, optimizer: EvoGraphOptimizer):

 def log_action_values_with_clusters(next_pop: PopulationT, optimizer: EvoGraphOptimizer):
     obs_contexts = optimizer.mutation.agent.get_context(next_pop)
-    self.cluster.partial_fit(np.array(obs_contexts))
+    self.cluster.partial_fit(obs_contexts)
     centers = self.cluster.cluster_centers_
     for i, center in enumerate(centers):
         values = optimizer.mutation.agent.get_action_values(obs=[center])
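
The wrapping is dropped here for the same reason: get_context already returns a 2-D (n_observations, n_features) array for a population. A hedged sketch, assuming the helper's cluster attribute is an incremental estimator such as sklearn's MiniBatchKMeans (the actual class is not shown in this diff):

    import numpy as np
    from sklearn.cluster import MiniBatchKMeans

    cluster = MiniBatchKMeans(n_clusters=2, n_init=3, random_state=0)
    obs_contexts = np.array([[0.1, 0.2], [0.8, 0.7], [0.2, 0.1]])  # made-up contexts
    cluster.partial_fit(obs_contexts)       # incremental update on one batch, no extra wrapping
    print(cluster.cluster_centers_.shape)   # (2, 2)
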
21 changes: 12 additions & 9 deletions golem/core/optimisers/adaptive/mab_agents/contextual_mab_agent.py
@@ -17,7 +17,7 @@ class ContextualMultiArmedBanditAgent(OperatorAgent):
     using NN to guarantee convergence.
     :param actions: types of mutations
-    :param context_agent: function to convert observation to its embedding. Can be specified as
+    :param context_agent_type: function to convert observation to its embedding. Can be specified as
     ContextAgentTypeEnum or as Callable function.
     :param available_operations: available operations
     :param n_jobs: n_jobs
@@ -50,22 +50,22 @@ def _initial_fit(self, obs: ObsType):
         n = len(self._indices)
         uniform_rewards = [1. / n] * n
         contexts = self.get_context(obs=obs)
-        self._agent.fit(decisions=self._indices, rewards=uniform_rewards, contexts=contexts * n)
+        self._agent.fit(decisions=self._indices, rewards=uniform_rewards, contexts=np.tile(contexts, (n, 1)))
         self._is_fitted = True

     def choose_action(self, obs: ObsType) -> ActType:
         if not self._is_fitted:
             self._initial_fit(obs=obs)
         contexts = self.get_context(obs=obs)
-        arm = self._agent.predict(contexts=np.array(contexts).reshape(1, -1))
+        arm = self._agent.predict(contexts=contexts.reshape(1, -1))
         action = self.actions[arm]
         return action

     def get_action_values(self, obs: Optional[ObsType] = None) -> Sequence[float]:
         if not self._is_fitted:
             self._initial_fit(obs=obs)
         contexts = self.get_context(obs)
-        prob_dict = self._agent.predict_expectations(contexts=np.array(contexts).reshape(1, -1))
+        prob_dict = self._agent.predict_expectations(contexts=contexts.reshape(1, -1))
         prob_list = [prob_dict[i] for i in range(len(prob_dict))]
         return prob_list

@@ -84,14 +84,17 @@ def partial_fit(self, experience: ExperienceBuffer):
         contexts = self.get_context(obs=obs)
         self._agent.partial_fit(decisions=arms, rewards=rewards, contexts=contexts)

-    def get_context(self, obs: Union[List[ObsType], ObsType]) -> List[List[float]]:
+    def get_context(self, obs: Union[List[ObsType], ObsType]) -> np.ndarray:
         """ Returns contexts based on specified context agent. """
         if not isinstance(obs, list):
-            return self._context_agent(obs)
+            return np.array(self._context_agent(obs)).flatten()
         contexts = []
         for ob in obs:
             if isinstance(ob, list) or isinstance(ob, np.ndarray):
+                # unify types: store each context as a flat numpy array
+                contexts.append(np.array(ob).flatten())
             else:
-                contexts.append(self._context_agent(ob))
-        return contexts
+                context = np.array(self._context_agent(ob))
+                # some external context agents may wrap the context in an extra array
+                contexts.append(context.flatten())
+        return np.array(contexts)
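
The np.tile change is the crux of this file. Once get_context returns an ndarray, contexts * n no longer repeats the context n times (as it did for a list) but scales its values elementwise, producing wrongly-scaled values instead of n identical rows for the n initial decisions. A standalone illustration:

    import numpy as np

    contexts = np.array([0.5, 1.0])   # a single context with 2 features
    n = 3                             # number of arms / initial decisions

    print(list(contexts) * n)         # [0.5, 1.0, 0.5, 1.0, 0.5, 1.0] -- old list behaviour, flat
    print(contexts * n)               # [1.5 3. ] -- ndarray arithmetic scales the values: wrong
    print(np.tile(contexts, (n, 1)))  # shape (3, 2): the context row repeated once per decision

The flattening in get_context serves the same end: every per-individual context, whether a list, an ndarray, or a wrapped array from an external context agent, is reduced to one 1-D row so that np.array(contexts) stacks into a clean 2-D matrix.
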
3 changes: 3 additions & 0 deletions golem/core/optimisers/genetic/operators/mutation.py
@@ -61,6 +61,9 @@ def _init_operator_agent(graph_gen_params: GraphGenerationParams,
         agent = NeuralContextualMultiArmedBanditAgent(actions=parameters.mutation_types,
                                                       context_agent_type=parameters.context_agent_type,
                                                       n_jobs=requirements.n_jobs)
+    # the agent can also be supplied directly as a pretrained OperatorAgent instance
+    elif isinstance(parameters.adaptive_mutation_type, OperatorAgent):
+        agent = kind
     else:
         raise TypeError(f'Unknown parameter {kind}')
     return agent
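
The new elif lets callers hand the optimizer an agent trained offline instead of an enum value. A runnable sketch of just that dispatch, using simplified stand-ins rather than GOLEM's real classes:

    class OperatorAgent:                 # stand-in for golem's OperatorAgent base class
        pass

    def init_operator_agent(kind):
        # a pretrained agent instance passes straight through; anything unknown still raises
        if isinstance(kind, OperatorAgent):
            return kind
        raise TypeError(f'Unknown parameter {kind}')

    pretrained = OperatorAgent()         # imagine this was fitted offline on logged mutations
    assert init_operator_agent(pretrained) is pretrained
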
2 changes: 2 additions & 0 deletions golem/core/optimisers/genetic/operators/reproduction.py
@@ -81,6 +81,8 @@ def reproduce_uncontrolled(self,
         # It can be faster if it could.
         selected_individuals = self.selection(population, pop_size)
         new_population = self.crossover(selected_individuals)
         new_population = self.mutation(new_population)
+
+        new_population = ensure_wrapped_in_sequence(self.mutation(new_population))
         new_population = evaluator(new_population)
         return new_population
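
Mutation applied to a population may hand back a single individual rather than a sequence, while the evaluator expects a sequence. A hedged stand-in for the ensure_wrapped_in_sequence utility (its real implementation lives in GOLEM's utilities and may differ):

    from collections.abc import Sequence

    def ensure_wrapped_in_sequence(obj):
        # assumption: anything that isn't a (non-string) sequence gets wrapped in a list
        if isinstance(obj, Sequence) and not isinstance(obj, str):
            return obj
        return [obj]

    print(ensure_wrapped_in_sequence(42))      # [42]
    print(ensure_wrapped_in_sequence([1, 2]))  # [1, 2]
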
