From 9f7d11a3bc98886532506df210bbc7f4e9c4aff2 Mon Sep 17 00:00:00 2001 From: Zach Klippenstein Date: Sun, 10 Jan 2016 13:33:22 -0800 Subject: [PATCH] Refactored ClientConfig into a Server interface. * StartServer is now a method on Server. * What used to be Dialer.Dial is now Server.Dial. * Server.Dial handles trying to start the server if the initial connection fails. * Dialer now dials a network address. * All types that took a Dialer now take a Server. * Server now has tests! --- client_config.go | 16 ----- cmd/adb/main.go | 17 +++-- cmd/demo/demo.go | 23 +++++-- cmd/raw-adb/raw-adb.go | 9 ++- device_client.go | 12 ++-- device_client_test.go | 22 ++----- device_watcher.go | 22 +++---- device_watcher_test.go | 31 ++------- dialer.go | 63 ++---------------- host_client.go | 14 ++-- host_client_test.go | 108 +----------------------------- server.go | 146 +++++++++++++++++++++++++++++++++++++++++ server_controller.go | 18 ----- server_mock_test.go | 117 +++++++++++++++++++++++++++++++++ server_test.go | 80 ++++++++++++++++++++++ 15 files changed, 420 insertions(+), 278 deletions(-) delete mode 100644 client_config.go create mode 100644 server.go delete mode 100644 server_controller.go create mode 100644 server_mock_test.go create mode 100644 server_test.go diff --git a/client_config.go b/client_config.go deleted file mode 100644 index b6867fe..0000000 --- a/client_config.go +++ /dev/null @@ -1,16 +0,0 @@ -package goadb - -var ( - defaultDialer Dialer = NewDialer("", 0) -) - -type ClientConfig struct { - Dialer Dialer -} - -func (c ClientConfig) sanitized() ClientConfig { - if c.Dialer == nil { - c.Dialer = defaultDialer - } - return c -} diff --git a/cmd/adb/main.go b/cmd/adb/main.go index 5b3c4ec..1eb92d8 100644 --- a/cmd/adb/main.go +++ b/cmd/adb/main.go @@ -36,9 +36,18 @@ var ( pushRemoteArg = pushCommand.Arg("remote", "Path of destination file on device.").Required().String() ) +var server goadb.Server + func main() { var exitCode int + var err error + server, err = goadb.NewServer(goadb.ServerConfig{}) + if err != nil { + fmt.Fprintln(os.Stderr, "error:", err) + os.Exit(1) + } + switch kingpin.Parse() { case "devices": exitCode = listDevices(*devicesLongFlag) @@ -62,7 +71,7 @@ func parseDevice() goadb.DeviceDescriptor { } func listDevices(long bool) int { - client := goadb.NewHostClient(goadb.ClientConfig{}) + client := goadb.NewHostClient(server) devices, err := client.ListDevices() if err != nil { fmt.Fprintln(os.Stderr, "error:", err) @@ -99,7 +108,7 @@ func runShellCommand(commandAndArgs []string, device goadb.DeviceDescriptor) int args = commandAndArgs[1:] } - client := goadb.NewDeviceClient(goadb.ClientConfig{}, device) + client := goadb.NewDeviceClient(server, device) output, err := client.RunCommand(command, args...) if err != nil { fmt.Fprintln(os.Stderr, "error:", err) @@ -121,7 +130,7 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe localPath = filepath.Base(remotePath) } - client := goadb.NewDeviceClient(goadb.ClientConfig{}, device) + client := goadb.NewDeviceClient(server, device) info, err := client.Stat(remotePath) if util.HasErrCode(err, util.FileNoExistError) { @@ -194,7 +203,7 @@ func push(showProgress bool, localPath, remotePath string, device goadb.DeviceDe } defer localFile.Close() - client := goadb.NewDeviceClient(goadb.ClientConfig{}, device) + client := goadb.NewDeviceClient(server, device) writer, err := client.OpenWrite(remotePath, perms, mtime) if err != nil { fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err) diff --git a/cmd/demo/demo.go b/cmd/demo/demo.go index 17a5d6c..fa0ff42 100644 --- a/cmd/demo/demo.go +++ b/cmd/demo/demo.go @@ -12,15 +12,26 @@ import ( "github.com/zach-klippenstein/goadb/util" ) -var port = flag.Int("p", adb.AdbPort, "") +var ( + port = flag.Int("p", adb.AdbPort, "") + + server adb.Server +) func main() { flag.Parse() - client := adb.NewHostClient(adb.ClientConfig{}) - + var err error + server, err = adb.NewServer(adb.ServerConfig{ + Port: *port, + }) + if err != nil { + log.Fatal(err) + } fmt.Println("Starting server…") - adb.StartServer() + server.Start() + + client := adb.NewHostClient(server) serverVersion, err := client.GetServerVersion() if err != nil { @@ -51,7 +62,7 @@ func main() { fmt.Println() fmt.Println("Watching for device state changes.") - watcher := adb.NewDeviceWatcher(adb.ClientConfig{}) + watcher := adb.NewDeviceWatcher(server) for event := range watcher.C() { fmt.Printf("\t[%s]%+v\n", time.Now(), event) } @@ -77,7 +88,7 @@ func printErr(err error) { } func PrintDeviceInfoAndError(descriptor adb.DeviceDescriptor) { - device := adb.NewDeviceClient(adb.ClientConfig{}, descriptor) + device := adb.NewDeviceClient(server, descriptor) if err := PrintDeviceInfo(device); err != nil { log.Println(err) } diff --git a/cmd/raw-adb/raw-adb.go b/cmd/raw-adb/raw-adb.go index a79a98a..aa70a92 100644 --- a/cmd/raw-adb/raw-adb.go +++ b/cmd/raw-adb/raw-adb.go @@ -49,7 +49,14 @@ func readLine() string { } func doCommand(cmd string) error { - conn, err := goadb.NewDialer("", *port).Dial() + server, err := goadb.NewServer(goadb.ServerConfig{ + Port: *port, + }) + if err != nil { + log.Fatal(err) + } + + conn, err := server.Dial() if err != nil { log.Fatal(err) } diff --git a/device_client.go b/device_client.go index ed2598c..3f9a30a 100644 --- a/device_client.go +++ b/device_client.go @@ -17,18 +17,18 @@ var MtimeOfClose = time.Time{} // DeviceClient communicates with a specific Android device. type DeviceClient struct { - config ClientConfig + server Server descriptor DeviceDescriptor // Used to get device info. deviceListFunc func() ([]*DeviceInfo, error) } -func NewDeviceClient(config ClientConfig, descriptor DeviceDescriptor) *DeviceClient { +func NewDeviceClient(server Server, descriptor DeviceDescriptor) *DeviceClient { return &DeviceClient{ - config: config.sanitized(), + server: server, descriptor: descriptor, - deviceListFunc: NewHostClient(config).ListDevices, + deviceListFunc: NewHostClient(server).ListDevices, } } @@ -194,7 +194,7 @@ func (c *DeviceClient) OpenWrite(path string, perms os.FileMode, mtime time.Time // getAttribute returns the first message returned by the server by running // :, where host-prefix is determined from the DeviceDescriptor. func (c *DeviceClient) getAttribute(attr string) (string, error) { - resp, err := roundTripSingleResponse(c.config.Dialer, + resp, err := roundTripSingleResponse(c.server, fmt.Sprintf("%s:%s", c.descriptor.getHostPrefix(), attr)) if err != nil { return "", err @@ -222,7 +222,7 @@ func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) { // dialDevice switches the connection to communicate directly with the device // by requesting the transport defined by the DeviceDescriptor. func (c *DeviceClient) dialDevice() (*wire.Conn, error) { - conn, err := c.config.Dialer.Dial() + conn, err := c.server.Dial() if err != nil { return nil, err } diff --git a/device_client_test.go b/device_client_test.go index 3c43eec..3a5b6f4 100644 --- a/device_client_test.go +++ b/device_client_test.go @@ -13,12 +13,7 @@ func TestGetAttribute(t *testing.T) { Status: wire.StatusSuccess, Messages: []string{"value"}, } - client := NewDeviceClient( - ClientConfig{ - Dialer: s, - }, - DeviceWithSerial("serial"), - ) + client := NewDeviceClient(s, DeviceWithSerial("serial")) v, err := client.getAttribute("attr") assert.Equal(t, "host-serial:serial:attr", s.Requests[0]) @@ -60,11 +55,9 @@ func TestGetDeviceInfo(t *testing.T) { func newDeviceClientWithDeviceLister(serial string, deviceLister func() ([]*DeviceInfo, error)) *DeviceClient { client := NewDeviceClient( - ClientConfig{ - Dialer: &MockServer{ - Status: wire.StatusSuccess, - Messages: []string{serial}, - }, + &MockServer{ + Status: wire.StatusSuccess, + Messages: []string{serial}, }, DeviceWithSerial(serial), ) @@ -77,12 +70,7 @@ func TestRunCommandNoArgs(t *testing.T) { Status: wire.StatusSuccess, Messages: []string{"output"}, } - client := NewDeviceClient( - ClientConfig{ - Dialer: s, - }, - AnyDevice(), - ) + client := NewDeviceClient(s, AnyDevice()) v, err := client.RunCommand("cmd") assert.Equal(t, "host:transport-any", s.Requests[0]) diff --git a/device_watcher.go b/device_watcher.go index 2f44166..095b151 100644 --- a/device_watcher.go +++ b/device_watcher.go @@ -2,10 +2,10 @@ package goadb import ( "log" + "math/rand" "runtime" "strings" "sync/atomic" - "math/rand" "time" "github.com/zach-klippenstein/goadb/util" @@ -59,22 +59,18 @@ var deviceStateStrings = map[string]DeviceState{ } type deviceWatcherImpl struct { - config ClientConfig + server Server // If an error occurs, it is stored here and eventChan is close immediately after. err atomic.Value eventChan chan DeviceStateChangedEvent - - // Function to start the server if it's not running or dies. - startServer func() error } -func NewDeviceWatcher(config ClientConfig) *DeviceWatcher { +func NewDeviceWatcher(server Server) *DeviceWatcher { watcher := &DeviceWatcher{&deviceWatcherImpl{ - config: config.sanitized(), - eventChan: make(chan DeviceStateChangedEvent), - startServer: StartServer, + server: server, + eventChan: make(chan DeviceStateChangedEvent), }} runtime.SetFinalizer(watcher, func(watcher *DeviceWatcher) { @@ -134,7 +130,7 @@ func publishDevices(watcher *deviceWatcherImpl) { finished := false for { - scanner, err := connectToTrackDevices(watcher.config.Dialer) + scanner, err := connectToTrackDevices(watcher.server) if err != nil { watcher.reportErr(err) return @@ -156,7 +152,7 @@ func publishDevices(watcher *deviceWatcherImpl) { log.Printf("[DeviceWatcher] server died, restarting in %s…", delay) time.Sleep(delay) - if err := watcher.startServer(); err != nil { + if err := watcher.server.Start(); err != nil { log.Println("[DeviceWatcher] error restarting server, giving up") watcher.reportErr(err) return @@ -169,8 +165,8 @@ func publishDevices(watcher *deviceWatcherImpl) { } } -func connectToTrackDevices(dialer Dialer) (wire.Scanner, error) { - conn, err := dialer.Dial() +func connectToTrackDevices(server Server) (wire.Scanner, error) { + conn, err := server.Dial() if err != nil { return nil, err } diff --git a/device_watcher_test.go b/device_watcher_test.go index df58595..27f0733 100644 --- a/device_watcher_test.go +++ b/device_watcher_test.go @@ -1,7 +1,6 @@ package goadb import ( - "log" "testing" "github.com/stretchr/testify/assert" @@ -207,8 +206,7 @@ func TestWentOffline(t *testing.T) { } func TestPublishDevicesRestartsServer(t *testing.T) { - starter := &MockServerStarter{} - dialer := &MockServer{ + server := &MockServer{ Status: wire.StatusSuccess, Errs: []error{ nil, nil, nil, // Successful dial. @@ -217,34 +215,17 @@ func TestPublishDevicesRestartsServer(t *testing.T) { }, } watcher := deviceWatcherImpl{ - config: ClientConfig{dialer}, - eventChan: make(chan DeviceStateChangedEvent), - startServer: starter.StartServer, + server: server, + eventChan: make(chan DeviceStateChangedEvent), } publishDevices(&watcher) - assert.Empty(t, dialer.Errs) - assert.Equal(t, []string{"host:track-devices"}, dialer.Requests) - assert.Equal(t, []string{"Dial", "SendMessage", "ReadStatus", "ReadMessage", "Dial"}, dialer.Trace) + assert.Empty(t, server.Errs) + assert.Equal(t, []string{"host:track-devices"}, server.Requests) + assert.Equal(t, []string{"Dial", "SendMessage", "ReadStatus", "ReadMessage", "Start", "Dial"}, server.Trace) err := watcher.err.Load().(*util.Err) assert.Equal(t, util.ServerNotAvailable, err.Code) - assert.Equal(t, 1, starter.startCount) -} - -type MockServerStarter struct { - startCount int - err error -} - -func (s *MockServerStarter) StartServer() error { - log.Printf("Starting mock server") - if s.err == nil { - s.startCount += 1 - return nil - } else { - return s.err - } } func assertContainsOnly(t *testing.T, expected, actual []DeviceStateChangedEvent) { diff --git a/dialer.go b/dialer.go index 83334a6..aa82082 100644 --- a/dialer.go +++ b/dialer.go @@ -1,7 +1,6 @@ package goadb import ( - "fmt" "io" "net" "runtime" @@ -10,61 +9,19 @@ import ( "github.com/zach-klippenstein/goadb/wire" ) -const ( - // Default port the adb server listens on. - AdbPort = 5037 -) - -/* -Dialer knows how to create connections to an adb server. -*/ +// Dialer knows how to create connections to an adb server. type Dialer interface { - Dial() (*wire.Conn, error) + Dial(address string) (*wire.Conn, error) } -/* -NewDialer creates a new Dialer. - -If host is "" or port is 0, "localhost:5037" is used. -*/ -func NewDialer(host string, port int) Dialer { - if host == "" { - host = "localhost" - } - if port == 0 { - port = AdbPort - } - return &netDialer{host, port} -} - -type netDialer struct { - Host string - Port int -} - -func (d *netDialer) String() string { - return fmt.Sprintf("netDialer(%s:%d)", d.Host, d.Port) -} +type tcpDialer struct{} // Dial connects to the adb server on the host and port set on the netDialer. // The zero-value will connect to the default, localhost:5037. -func (d *netDialer) Dial() (*wire.Conn, error) { - host := d.Host - port := d.Port - - address := fmt.Sprintf("%s:%d", host, port) +func (tcpDialer) Dial(address string) (*wire.Conn, error) { netConn, err := net.Dial("tcp", address) if err != nil { - // Attempt to start the server and try again. - if err = StartServer(); err != nil { - return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error starting server") - } - - address = fmt.Sprintf("%s:%d", host, port) - netConn, err = net.Dial("tcp", address) - if err != nil { - return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error dialing %s", address) - } + return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error dialing %s", address) } // net.Conn can't be closed more than once, but wire.Conn will try to close both sender and scanner @@ -84,13 +41,3 @@ func (d *netDialer) Dial() (*wire.Conn, error) { Sender: wire.NewSender(safeConn), }, nil } - -func roundTripSingleResponse(d Dialer, req string) ([]byte, error) { - conn, err := d.Dial() - if err != nil { - return nil, err - } - defer conn.Close() - - return conn.RoundTripSingleResponse([]byte(req)) -} diff --git a/host_client.go b/host_client.go index ebca1b7..38a6ae5 100644 --- a/host_client.go +++ b/host_client.go @@ -19,16 +19,16 @@ See list of services at https://android.googlesource.com/platform/system/core/+/ */ // TODO(z): Finish implementing host services. type HostClient struct { - config ClientConfig + server Server } -func NewHostClient(config ClientConfig) *HostClient { - return &HostClient{config.sanitized()} +func NewHostClient(server Server) *HostClient { + return &HostClient{server} } // GetServerVersion asks the ADB server for its internal version number. func (c *HostClient) GetServerVersion() (int, error) { - resp, err := roundTripSingleResponse(c.config.Dialer, "host:version") + resp, err := roundTripSingleResponse(c.server, "host:version") if err != nil { return 0, wrapClientError(err, c, "GetServerVersion") } @@ -47,7 +47,7 @@ Corresponds to the command: adb kill-server */ func (c *HostClient) KillServer() error { - conn, err := c.config.Dialer.Dial() + conn, err := c.server.Dial() if err != nil { return wrapClientError(err, c, "KillServer") } @@ -67,7 +67,7 @@ Corresponds to the command: adb devices */ func (c *HostClient) ListDeviceSerials() ([]string, error) { - resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices") + resp, err := roundTripSingleResponse(c.server, "host:devices") if err != nil { return nil, wrapClientError(err, c, "ListDeviceSerials") } @@ -91,7 +91,7 @@ Corresponds to the command: adb devices -l */ func (c *HostClient) ListDevices() ([]*DeviceInfo, error) { - resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices-l") + resp, err := roundTripSingleResponse(c.server, "host:devices-l") if err != nil { return nil, wrapClientError(err, c, "ListDevices") } diff --git a/host_client_test.go b/host_client_test.go index de6226e..7056cae 100644 --- a/host_client_test.go +++ b/host_client_test.go @@ -1,12 +1,9 @@ package goadb import ( - "io" - "strings" "testing" "github.com/stretchr/testify/assert" - "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) @@ -15,113 +12,10 @@ func TestGetServerVersion(t *testing.T) { Status: wire.StatusSuccess, Messages: []string{"000a"}, } - client := NewHostClient(ClientConfig{ - Dialer: s, - }) + client := NewHostClient(s) v, err := client.GetServerVersion() assert.Equal(t, "host:version", s.Requests[0]) assert.NoError(t, err) assert.Equal(t, 10, v) } - -// MockServer implements Dialer, Scanner, and Sender. -type MockServer struct { - // Each time an operation is performed, if this slice is non-empty, the head element - // of this slice is returned and removed from the slice. If the head is nil, it is removed - // but not returned. - Errs []error - - Status string - - // Messages are returned from read calls in order, each preceded by a length header. - Messages []string - nextMsgIndex int - - // Each message passed to a send call is appended to this slice. - Requests []string - - // Each time an operaiton is performed, its name is appended to this slice. - Trace []string -} - -func (s *MockServer) Dial() (*wire.Conn, error) { - s.logMethod("Dial") - if err := s.getNextErrToReturn(); err != nil { - return nil, err - } - return wire.NewConn(s, s), nil -} - -func (s *MockServer) ReadStatus(req string) (string, error) { - s.logMethod("ReadStatus") - if err := s.getNextErrToReturn(); err != nil { - return "", err - } - return s.Status, nil -} - -func (s *MockServer) ReadMessage() ([]byte, error) { - s.logMethod("ReadMessage") - if err := s.getNextErrToReturn(); err != nil { - return nil, err - } - if s.nextMsgIndex >= len(s.Messages) { - return nil, util.WrapErrorf(io.EOF, util.NetworkError, "") - } - - s.nextMsgIndex++ - return []byte(s.Messages[s.nextMsgIndex-1]), nil -} - -func (s *MockServer) ReadUntilEof() ([]byte, error) { - s.logMethod("ReadUntilEof") - if err := s.getNextErrToReturn(); err != nil { - return nil, err - } - - var data []string - for ; s.nextMsgIndex < len(s.Messages); s.nextMsgIndex++ { - data = append(data, s.Messages[s.nextMsgIndex]) - } - return []byte(strings.Join(data, "")), nil -} - -func (s *MockServer) SendMessage(msg []byte) error { - s.logMethod("SendMessage") - if err := s.getNextErrToReturn(); err != nil { - return err - } - s.Requests = append(s.Requests, string(msg)) - return nil -} - -func (s *MockServer) NewSyncScanner() wire.SyncScanner { - s.logMethod("NewSyncScanner") - return nil -} - -func (s *MockServer) NewSyncSender() wire.SyncSender { - s.logMethod("NewSyncSender") - return nil -} - -func (s *MockServer) Close() error { - s.logMethod("Close") - if err := s.getNextErrToReturn(); err != nil { - return err - } - return nil -} - -func (s *MockServer) getNextErrToReturn() (err error) { - if len(s.Errs) > 0 { - err = s.Errs[0] - s.Errs = s.Errs[1:] - } - return -} - -func (s *MockServer) logMethod(name string) { - s.Trace = append(s.Trace, name) -} diff --git a/server.go b/server.go new file mode 100644 index 0000000..ae6aab0 --- /dev/null +++ b/server.go @@ -0,0 +1,146 @@ +package goadb + +import ( + "errors" + "fmt" + "os" + "os/exec" + "strings" + + "github.com/zach-klippenstein/goadb/util" + "github.com/zach-klippenstein/goadb/wire" + "golang.org/x/sys/unix" +) + +const ( + AdbExecutableName = "adb" + + // Default port the adb server listens on. + AdbPort = 5037 +) + +type ServerConfig struct { + // Path to the adb executable. If empty, the PATH environment variable will be searched. + PathToAdb string + + // Host and port the adb server is listening on. + // If not specified, will use the default port on localhost. + Host string + Port int + + // Dialer used to connect to the adb server. + Dialer +} + +// Server knows how to start the adb server and connect to it. +type Server interface { + Start() error + Dial() (*wire.Conn, error) +} + +func roundTripSingleResponse(s Server, req string) ([]byte, error) { + conn, err := s.Dial() + if err != nil { + return nil, err + } + defer conn.Close() + + return conn.RoundTripSingleResponse([]byte(req)) +} + +type realServer struct { + config ServerConfig + fs *filesystem + + // Caches Host:Port so they don't have to be concatenated for every dial. + address string +} + +// NewServer creates a new Server instance. +func NewServer(config ServerConfig) (Server, error) { + return newServer(config, localFilesystem) +} + +func newServer(config ServerConfig, fs *filesystem) (Server, error) { + if config.Dialer == nil { + config.Dialer = tcpDialer{} + } + + if config.Host == "" { + config.Host = "localhost" + } + if config.Port == 0 { + config.Port = AdbPort + } + + if config.PathToAdb == "" { + path, err := fs.LookPath(AdbExecutableName) + if err != nil { + return nil, util.WrapErrorf(err, util.ServerNotAvailable, "could not find %s in PATH", AdbExecutableName) + } + config.PathToAdb = path + } + if err := fs.IsExecutableFile(config.PathToAdb); err != nil { + return nil, util.WrapErrorf(err, util.ServerNotAvailable, "invalid adb executable: %s", config.PathToAdb) + } + + return &realServer{ + config: config, + fs: fs, + address: fmt.Sprintf("%s:%d", config.Host, config.Port), + }, nil +} + +// Dial tries to connect to the server. If the first attempt fails, tries starting the server before +// retrying. If the second attempt fails, returns the error. +func (s *realServer) Dial() (*wire.Conn, error) { + conn, err := s.config.Dial(s.address) + if err != nil { + // Attempt to start the server and try again. + if err = s.Start(); err != nil { + return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error starting server for dial") + } + + conn, err = s.config.Dial(s.address) + if err != nil { + return nil, err + } + } + return conn, nil +} + +// StartServer ensures there is a server running. +func (s *realServer) Start() error { + output, err := s.fs.CmdCombinedOutput(s.config.PathToAdb, "start-server") + outputStr := strings.TrimSpace(string(output)) + return util.WrapErrorf(err, util.ServerNotAvailable, "error starting server: %s\noutput:\n%s", err, outputStr) +} + +// filesystem abstracts interactions with the local filesystem for testability. +type filesystem struct { + // Wraps exec.LookPath. + LookPath func(string) (string, error) + + // Returns nil if path is a regular file and executable by the current user. + IsExecutableFile func(path string) error + + // Wraps exec.Command().CombinedOutput() + CmdCombinedOutput func(name string, arg ...string) ([]byte, error) +} + +var localFilesystem = &filesystem{ + LookPath: exec.LookPath, + IsExecutableFile: func(path string) error { + info, err := os.Stat(path) + if err != nil { + return err + } + if !info.Mode().IsRegular() { + return errors.New("not a regular file") + } + return unix.Access(path, unix.X_OK) + }, + CmdCombinedOutput: func(name string, arg ...string) ([]byte, error) { + return exec.Command(name, arg...).CombinedOutput() + }, +} diff --git a/server_controller.go b/server_controller.go deleted file mode 100644 index a3a3f10..0000000 --- a/server_controller.go +++ /dev/null @@ -1,18 +0,0 @@ -package goadb - -import ( - "os/exec" - "strings" - - "github.com/zach-klippenstein/goadb/util" -) - -/* -StartServer ensures there is a server running. -*/ -func StartServer() error { - cmd := exec.Command("adb", "start-server") - output, err := cmd.CombinedOutput() - outputStr := strings.TrimSpace(string(output)) - return util.WrapErrorf(err, util.ServerNotAvailable, "error starting server: %s\noutput:\n%s", err, outputStr) -} diff --git a/server_mock_test.go b/server_mock_test.go new file mode 100644 index 0000000..2f7dc35 --- /dev/null +++ b/server_mock_test.go @@ -0,0 +1,117 @@ +package goadb + +import ( + "io" + "strings" + + "github.com/zach-klippenstein/goadb/util" + "github.com/zach-klippenstein/goadb/wire" +) + +// MockServer implements Server, Scanner, and Sender. +type MockServer struct { + // Each time an operation is performed, if this slice is non-empty, the head element + // of this slice is returned and removed from the slice. If the head is nil, it is removed + // but not returned. + Errs []error + + Status string + + // Messages are returned from read calls in order, each preceded by a length header. + Messages []string + nextMsgIndex int + + // Each message passed to a send call is appended to this slice. + Requests []string + + // Each time an operation is performed, its name is appended to this slice. + Trace []string +} + +var _ Server = &MockServer{} + +func (s *MockServer) Dial() (*wire.Conn, error) { + s.logMethod("Dial") + if err := s.getNextErrToReturn(); err != nil { + return nil, err + } + return wire.NewConn(s, s), nil +} + +func (s *MockServer) Start() error { + s.logMethod("Start") + return nil +} + +func (s *MockServer) ReadStatus(req string) (string, error) { + s.logMethod("ReadStatus") + if err := s.getNextErrToReturn(); err != nil { + return "", err + } + return s.Status, nil +} + +func (s *MockServer) ReadMessage() ([]byte, error) { + s.logMethod("ReadMessage") + if err := s.getNextErrToReturn(); err != nil { + return nil, err + } + if s.nextMsgIndex >= len(s.Messages) { + return nil, util.WrapErrorf(io.EOF, util.NetworkError, "") + } + + s.nextMsgIndex++ + return []byte(s.Messages[s.nextMsgIndex-1]), nil +} + +func (s *MockServer) ReadUntilEof() ([]byte, error) { + s.logMethod("ReadUntilEof") + if err := s.getNextErrToReturn(); err != nil { + return nil, err + } + + var data []string + for ; s.nextMsgIndex < len(s.Messages); s.nextMsgIndex++ { + data = append(data, s.Messages[s.nextMsgIndex]) + } + return []byte(strings.Join(data, "")), nil +} + +func (s *MockServer) SendMessage(msg []byte) error { + s.logMethod("SendMessage") + if err := s.getNextErrToReturn(); err != nil { + return err + } + s.Requests = append(s.Requests, string(msg)) + return nil +} + +func (s *MockServer) NewSyncScanner() wire.SyncScanner { + s.logMethod("NewSyncScanner") + return nil +} + +func (s *MockServer) NewSyncSender() wire.SyncSender { + s.logMethod("NewSyncSender") + return nil +} + +func (s *MockServer) Close() error { + s.logMethod("Close") + if err := s.getNextErrToReturn(); err != nil { + return err + } + return nil +} + +func (s *MockServer) getNextErrToReturn() (err error) { + if len(s.Errs) > 0 { + err = s.Errs[0] + s.Errs = s.Errs[1:] + } + return +} + +func (s *MockServer) logMethod(name string) { + s.Trace = append(s.Trace, name) +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..4bf7a2c --- /dev/null +++ b/server_test.go @@ -0,0 +1,80 @@ +package goadb + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/wire" +) + +func TestNewServer_ZeroConfig(t *testing.T) { + config := ServerConfig{} + fs := &filesystem{ + LookPath: func(name string) (string, error) { + if name == AdbExecutableName { + return "/bin/adb", nil + } + return "", fmt.Errorf("invalid name: %s", name) + }, + IsExecutableFile: func(path string) error { + if path == "/bin/adb" { + return nil + } + return fmt.Errorf("wrong path: %s", path) + }, + } + + serverIf, err := newServer(config, fs) + server := serverIf.(*realServer) + assert.NoError(t, err) + assert.IsType(t, tcpDialer{}, server.config.Dialer) + assert.Equal(t, "localhost", server.config.Host) + assert.Equal(t, AdbPort, server.config.Port) + assert.Equal(t, fmt.Sprintf("localhost:%d", AdbPort), server.address) + assert.Equal(t, "/bin/adb", server.config.PathToAdb) +} + +type MockDialer struct{} + +func (d MockDialer) Dial(address string) (*wire.Conn, error) { + return nil, nil +} + +func TestNewServer_CustomConfig(t *testing.T) { + config := ServerConfig{ + Dialer: MockDialer{}, + Host: "foobar", + Port: 1, + PathToAdb: "/bin/adb", + } + fs := &filesystem{ + IsExecutableFile: func(path string) error { + if path == "/bin/adb" { + return nil + } + return fmt.Errorf("wrong path: %s", path) + }, + } + + serverIf, err := newServer(config, fs) + server := serverIf.(*realServer) + assert.NoError(t, err) + assert.IsType(t, MockDialer{}, server.config.Dialer) + assert.Equal(t, "foobar", server.config.Host) + assert.Equal(t, 1, server.config.Port) + assert.Equal(t, fmt.Sprintf("foobar:1"), server.address) + assert.Equal(t, "/bin/adb", server.config.PathToAdb) +} + +func TestNewServer_AdbNotFound(t *testing.T) { + config := ServerConfig{} + fs := &filesystem{ + LookPath: func(name string) (string, error) { + return "", fmt.Errorf("executable not found: %s", name) + }, + } + + _, err := newServer(config, fs) + assert.EqualError(t, err, "ServerNotAvailable: could not find adb in PATH") +}