diff --git a/agent.go b/agent.go index eb66de82..2ed28790 100644 --- a/agent.go +++ b/agent.go @@ -1196,8 +1196,10 @@ func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error { // Restart restarts the ICE Agent with the provided ufrag/pwd // If no ufrag/pwd is provided the Agent will generate one itself // -// Restart must only be called when GatheringState is GatheringStateComplete -// a user must then call GatherCandidates explicitly to start generating new ones +// If there is a gatherer routine currently running, Restart will +// cancel it. +// After a Restart, the user must then call GatherCandidates explicitly +// to start generating new ones. func (a *Agent) Restart(ufrag, pwd string) error { if ufrag == "" { var err error @@ -1224,8 +1226,7 @@ func (a *Agent) Restart(ufrag, pwd string) error { var err error if runErr := a.run(a.context(), func(ctx context.Context, agent *Agent) { if agent.gatheringState == GatheringStateGathering { - err = ErrRestartWhenGathering - return + agent.gatherCandidateCancel() } // Clear all agent needed to take back to fresh state diff --git a/agent_test.go b/agent_test.go index 598db5c1..84c273cf 100644 --- a/agent_test.go +++ b/agent_test.go @@ -1375,13 +1375,24 @@ func TestAgentRestart(t *testing.T) { oneSecond := time.Second t.Run("Restart During Gather", func(t *testing.T) { - agent, err := NewAgent(&AgentConfig{}) - assert.NoError(t, err) + connA, connB := pipe(&AgentConfig{ + DisconnectedTimeout: &oneSecond, + FailedTimeout: &oneSecond, + }) + + ctx, cancel := context.WithCancel(context.Background()) + assert.NoError(t, connB.agent.OnConnectionStateChange(func(c ConnectionState) { + if c == ConnectionStateFailed || c == ConnectionStateDisconnected { + cancel() + } + })) - agent.gatheringState = GatheringStateGathering + connA.agent.gatheringState = GatheringStateGathering + assert.NoError(t, connA.agent.Restart("", "")) - assert.Equal(t, ErrRestartWhenGathering, agent.Restart("", "")) - assert.NoError(t, agent.Close()) + <-ctx.Done() + assert.NoError(t, connA.agent.Close()) + assert.NoError(t, connB.agent.Close()) }) t.Run("Restart When Closed", func(t *testing.T) { diff --git a/errors.go b/errors.go index 838e714c..dbda8bd7 100644 --- a/errors.go +++ b/errors.go @@ -97,9 +97,6 @@ var ( // ErrInvalidMulticastDNSHostName indicates an invalid MulticastDNSHostName ErrInvalidMulticastDNSHostName = errors.New("invalid mDNS HostName, must end with .local and can only contain a single '.'") - // ErrRestartWhenGathering indicates Restart was called when Agent is in GatheringStateGathering - ErrRestartWhenGathering = errors.New("ICE Agent can not be restarted when gathering") - // ErrRunCanceled indicates a run operation was canceled by its individual done ErrRunCanceled = errors.New("run was canceled by done")