package proxy
import (
"sync"
"testing"
"time"
)
// Message fragments shared by the Dial test failures in this file, kept in
// one place so every test reports problems with the same wording.
const (
	// testMsgExpectedError introduces a failure where Dial should have failed.
	testMsgExpectedError = "expected error"
	// testMsgUnexpectedNilError explains that Dial returned a nil error instead.
	testMsgUnexpectedNilError = "unexpected nil error"
	// testMsgWrongError introduces a failure where Dial's error text did not match.
	testMsgWrongError = "wrong error message"
)
// TestDial_TransportStopped_InitialCheck verifies that calling Dial on a
// transport that has already been started and then stopped fails, and that
// the error text equals ErrTransportStopped.
func TestDial_TransportStopped_InitialCheck(t *testing.T) {
	tr := newTransport("test_initial_stop", "127.0.0.1:0")
	tr.Start()
	tr.Stop()

	// Short settling pause between Stop and the Dial under test.
	time.Sleep(50 * time.Millisecond)

	_, _, dialErr := tr.Dial("udp")
	switch {
	case dialErr == nil:
		t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError)
	case dialErr.Error() != ErrTransportStopped:
		t.Errorf("%s: got '%v', want '%s'", testMsgWrongError, dialErr, ErrTransportStopped)
	}
}
// TestDial_TransportStoppedDuringDialSend verifies that a Dial issued on a
// transport that was never started returns ErrTransportStoppedDuringDial once
// Stop is called.
//
// NOTE(review): this relies on Dial blocking while handing its request to the
// (unstarted) transport until Stop is observed — an assumption about Dial's
// internals that is not visible here.
func TestDial_TransportStoppedDuringDialSend(t *testing.T) {
	tr := newTransport("test_during_dial_send", "127.0.0.1:0")

	// Buffered so the Dial goroutine can always deliver its result and exit,
	// even if the test has already stopped waiting for it.
	dialErrChan := make(chan error, 1)
	go func() {
		_, _, err := tr.Dial("udp")
		dialErrChan <- err
	}()

	// Give the goroutine time to enter Dial before stopping the transport.
	time.Sleep(50 * time.Millisecond)
	tr.Stop()

	// Bound the wait: if Stop fails to unblock Dial, report a failure instead
	// of hanging until the go test binary's global timeout.
	var err error
	select {
	case err = <-dialErrChan:
	case <-time.After(2 * time.Second):
		t.Fatal("Test: Timeout waiting for Dial to return after transport stop")
	}

	if err == nil {
		t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError)
	}
	if err.Error() != ErrTransportStoppedDuringDial {
		t.Errorf("%s: got '%v', want '%s'", testMsgWrongError, err, ErrTransportStoppedDuringDial)
	}
}
// TestDial_TransportStoppedDuringRetWait verifies that a Dial which has already
// handed its request off on tr.dial, and is still waiting for a reply, returns
// ErrTransportStoppedDuringRetWait when the transport is stopped.
//
// The test swaps in its own unbuffered tr.dial/tr.ret channels before Start so
// it can play the connManager role: it consumes the request from tr.dial and
// never answers on tr.ret.
func TestDial_TransportStoppedDuringRetWait(t *testing.T) {
	tr := newTransport("test_during_ret_wait", "127.0.0.1:0")
	tr.dial = make(chan string)
	tr.ret = make(chan *persistConn)
	tr.Start()

	// Stop exactly once, whether the test reaches the explicit stop below or
	// bails out early through t.Fatal (t.Fatal runs deferred calls). The
	// sync.Once guards against calling Stop twice, since Stop's idempotence is
	// not visible from here.
	var stopOnce sync.Once
	stop := func() { stopOnce.Do(func() { tr.Stop() }) }
	defer stop()

	// Buffered so the Dial goroutine can always exit, even if the test has
	// already given up waiting for it.
	dialErrChan := make(chan error, 1)
	go func() {
		_, _, err := tr.Dial("udp")
		dialErrChan <- err
	}()

	select {
	case protoFromDial := <-tr.dial:
		t.Logf("Test: Simulated connManager read '%s' from Dial via test-controlled tr.dial", protoFromDial)
	case <-time.After(500 * time.Millisecond):
		t.Fatal("Test: Timeout waiting for Dial to send on test-controlled tr.dial")
	}

	// Dial has handed off its request, so it is past any initial "already
	// stopped" check and can only be blocked waiting on tr.ret.
	stop()

	// Bound the wait so a Dial that ignores Stop fails the test instead of
	// hanging it.
	var err error
	select {
	case err = <-dialErrChan:
	case <-time.After(2 * time.Second):
		t.Fatal("Test: Timeout waiting for Dial to return after transport stop")
	}

	if err == nil {
		t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError)
	}
	if err.Error() != ErrTransportStoppedDuringRetWait {
		t.Errorf("%s: got '%v', want '%s'", testMsgWrongError, err, ErrTransportStoppedDuringRetWait)
		return
	}
	t.Logf("SUCCESS: Dial correctly returned '%s'", ErrTransportStoppedDuringRetWait)
}
// TestDial_Returns_ErrTransportStoppedRetClosed verifies that a Dial waiting
// for a reply returns ErrTransportStoppedRetClosed when tr.ret is closed
// instead of delivering a connection.
//
// The transport is never started. The test installs its own tr.dial/tr.ret
// channels and plays the connManager role itself: it reads the request from
// tr.dial, then closes tr.ret.
func TestDial_Returns_ErrTransportStoppedRetClosed(t *testing.T) {
	tr := newTransport("test_returns_ret_closed", "127.0.0.1:0")
	// Buffered so Dial's request send completes without a reader.
	testDialChan := make(chan string, 1)
	testRetChan := make(chan *persistConn)
	tr.dial = testDialChan
	tr.ret = testRetChan

	// Deferred calls run LIFO, so on every exit path (including t.Fatal)
	// testRetChan is closed first, which unblocks Dial, and tr.Stop runs after.
	// This is the same close-then-Stop order as the happy path. The sync.Once
	// makes the explicit close below and the deferred one safe together.
	defer tr.Stop()
	var closeRetOnce sync.Once
	closeRet := func() { closeRetOnce.Do(func() { close(testRetChan) }) }
	defer closeRet()

	// Buffered so the Dial goroutine can always exit, even if the test has
	// already given up waiting for it. No WaitGroup is needed: receiving from
	// this channel already implies Dial has returned.
	dialErrChan := make(chan error, 1)
	go func() {
		_, _, err := tr.Dial("udp")
		dialErrChan <- err
	}()

	select {
	case proto := <-testDialChan:
		if proto != "udp" {
			t.Fatalf("Test: Dial sent wrong proto on testDialChan: got %s, want udp", proto)
		}
		t.Logf("Test: Simulated connManager received '%s' from Dial via testDialChan.", proto)
	case <-time.After(500 * time.Millisecond):
		t.Fatal("Test: Timeout waiting for Dial to send on testDialChan.")
	}

	closeRet()
	t.Logf("Test: Closed testRetChan (simulating connManager closing tr.ret).")

	// Bound the wait so a Dial that ignores the closed channel fails the test
	// instead of hanging it.
	var err error
	select {
	case err = <-dialErrChan:
	case <-time.After(2 * time.Second):
		t.Fatal("Test: Timeout waiting for Dial to return after testRetChan was closed.")
	}

	if err == nil {
		t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError)
	}
	if err.Error() != ErrTransportStoppedRetClosed {
		t.Errorf("%s: got '%v', want '%s'", testMsgWrongError, err, ErrTransportStoppedRetClosed)
		return
	}
	t.Logf("SUCCESS: Dial correctly returned '%s'", ErrTransportStoppedRetClosed)
}
// TestDial_ConnManagerClosesRetOnStop checks that tr.ret has been closed
// shortly after a started transport is stopped.
//
// A Dial runs concurrently so that Stop happens while the transport has a
// request in flight. That Dial's outcome is only logged, never asserted.
func TestDial_ConnManagerClosesRetOnStop(t *testing.T) {
	const (
		dialLeadTime   = 100 * time.Millisecond // let the Dial get going before Stop
		stopSettleTime = 50 * time.Millisecond  // pause between Stop and inspecting tr.ret
		dialWaitLimit  = 500 * time.Millisecond // how long to wait for the Dial's result
	)

	tr := newTransport("test_connmanager_closes_ret", "127.0.0.1:0")
	tr.Start()

	dialResult := make(chan error, 1)
	go func() {
		_, _, dialErr := tr.Dial("udp")
		dialResult <- dialErr
	}()

	time.Sleep(dialLeadTime)
	tr.Stop()
	time.Sleep(stopSettleTime)

	// Non-blocking probe of tr.ret. A closed channel yields ok == false
	// immediately, and an open channel with nothing ready falls through to default.
	select {
	case _, stillOpen := <-tr.ret:
		if stillOpen {
			t.Errorf("FAIL: tr.ret channel was not closed after transport stop, or a value was read unexpectedly.")
		} else {
			t.Logf("SUCCESS: tr.ret channel is closed as expected after transport stop.")
		}
	default:
		t.Errorf("FAIL: tr.ret channel is not closed and is blocking (or empty but open).")
	}

	select {
	case <-time.After(dialWaitLimit):
		t.Logf("Timeout waiting for interaction Dial to complete.")
	case dialErr := <-dialResult:
		if dialErr == nil {
			t.Logf("Interaction Dial completed without error.")
		} else {
			t.Logf("Interaction Dial completed with error (possibly expected due to 127.0.0.1:0 or race with Stop): %v", dialErr)
		}
	}
}
// TestDial_MultipleCallsAfterStop verifies that every Dial on a stopped
// transport fails with ErrTransportStopped, not only the first one.
func TestDial_MultipleCallsAfterStop(t *testing.T) {
	const attempts = 3

	tr := newTransport("test_multiple_after_stop", "127.0.0.1:0")
	tr.Start()
	tr.Stop()
	// Short settling pause between Stop and the Dials under test.
	time.Sleep(50 * time.Millisecond)

	for attempt := 1; attempt <= attempts; attempt++ {
		_, _, dialErr := tr.Dial("udp")
		switch {
		case dialErr == nil:
			t.Errorf("Attempt %d: %s: %s", attempt, testMsgExpectedError, testMsgUnexpectedNilError)
		case dialErr.Error() != ErrTransportStopped:
			t.Errorf("Attempt %d: %s: got '%v', want '%s'", attempt, testMsgWrongError, dialErr, ErrTransportStopped)
		}
	}
}