diff --git a/cmd/adb/main.go b/cmd/adb/main.go index b491679..2dfddf2 100644 --- a/cmd/adb/main.go +++ b/cmd/adb/main.go @@ -131,7 +131,7 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe remoteFile, err := client.OpenRead(remotePath) if err != nil { - fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err) + fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, util.ErrorWithCauseChain(err)) return 1 } defer remoteFile.Close() diff --git a/cmd/raw-adb/raw-adb.go b/cmd/raw-adb/raw-adb.go index 231e830..a79a98a 100644 --- a/cmd/raw-adb/raw-adb.go +++ b/cmd/raw-adb/raw-adb.go @@ -59,7 +59,7 @@ func doCommand(cmd string) error { return err } - status, err := conn.ReadStatus() + status, err := conn.ReadStatus("") if err != nil { return err } diff --git a/device_client.go b/device_client.go index cd3d3c7..d69102f 100644 --- a/device_client.go +++ b/device_client.go @@ -112,7 +112,7 @@ func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) { if err = conn.SendMessage([]byte(req)); err != nil { return "", wrapClientError(err, c, "RunCommand") } - if err = wire.ReadStatusFailureAsError(conn, req); err != nil { + if _, err = conn.ReadStatus(req); err != nil { return "", wrapClientError(err, c, "RunCommand") } @@ -192,7 +192,7 @@ func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) { if err := wire.SendMessageString(conn, "sync:"); err != nil { return nil, err } - if err := wire.ReadStatusFailureAsError(conn, "sync"); err != nil { + if _, err := conn.ReadStatus("sync"); err != nil { return nil, err } @@ -213,7 +213,7 @@ func (c *DeviceClient) dialDevice() (*wire.Conn, error) { return nil, util.WrapErrf(err, "error connecting to device '%s'", c.descriptor) } - if err = wire.ReadStatusFailureAsError(conn, req); err != nil { + if _, err = conn.ReadStatus(req); err != nil { conn.Close() return nil, err } diff --git a/device_watcher.go b/device_watcher.go index 48fdf95..2f44166 100644 --- a/device_watcher.go +++ b/device_watcher.go @@ -180,7 +180,7 @@ func connectToTrackDevices(dialer Dialer) (wire.Scanner, error) { return nil, err } - if err := wire.ReadStatusFailureAsError(conn, "host:track-devices"); err != nil { + if _, err := conn.ReadStatus("host:track-devices"); err != nil { conn.Close() return nil, err } diff --git a/dir_entries.go b/dir_entries.go index a213b6b..202609e 100644 --- a/dir_entries.go +++ b/dir_entries.go @@ -74,16 +74,16 @@ func (entries *DirEntries) Close() error { } func readNextDirListEntry(s wire.SyncScanner) (entry *DirEntry, done bool, err error) { - id, err := s.ReadOctetString() + status, err := s.ReadStatus("dir-entry") if err != nil { return } - if id == "DONE" { + if status == "DONE" { done = true return - } else if id != "DENT" { - err = fmt.Errorf("error reading dir entries: expected dir entry ID 'DENT', but got '%s'", id) + } else if status != "DENT" { + err = fmt.Errorf("error reading dir entries: expected dir entry ID 'DENT', but got '%s'", status) return } diff --git a/host_client_test.go b/host_client_test.go index 3acacdf..de6226e 100644 --- a/host_client_test.go +++ b/host_client_test.go @@ -32,7 +32,7 @@ type MockServer struct { // but not returned. Errs []error - Status wire.StatusCode + Status string // Messages are returned from read calls in order, each preceded by a length header. Messages []string @@ -53,7 +53,7 @@ func (s *MockServer) Dial() (*wire.Conn, error) { return wire.NewConn(s, s), nil } -func (s *MockServer) ReadStatus() (wire.StatusCode, error) { +func (s *MockServer) ReadStatus(req string) (string, error) { s.logMethod("ReadStatus") if err := s.getNextErrToReturn(); err != nil { return "", err diff --git a/sync_client.go b/sync_client.go index 4ec31a0..c3497d6 100644 --- a/sync_client.go +++ b/sync_client.go @@ -20,7 +20,7 @@ func stat(conn *wire.SyncConn, path string) (*DirEntry, error) { return nil, err } - id, err := conn.ReadOctetString() + id, err := conn.ReadStatus("stat") if err != nil { return nil, err } @@ -49,8 +49,7 @@ func receiveFile(conn *wire.SyncConn, path string) (io.ReadCloser, error) { if err := conn.SendString(path); err != nil { return nil, err } - - return newSyncFileReader(conn), nil + return newSyncFileReader(conn) } func readStat(s wire.SyncScanner) (entry *DirEntry, err error) { diff --git a/sync_file_reader.go b/sync_file_reader.go index 6ab1c04..69f474d 100644 --- a/sync_file_reader.go +++ b/sync_file_reader.go @@ -4,6 +4,7 @@ import ( "fmt" "io" + "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) @@ -18,10 +19,17 @@ type syncFileReader struct { var _ io.ReadCloser = &syncFileReader{} -func newSyncFileReader(s wire.SyncScanner) io.ReadCloser { - return &syncFileReader{ +func newSyncFileReader(s wire.SyncScanner) (r io.ReadCloser, err error) { + r = &syncFileReader{ scanner: s, } + + // Read the header for the first chunk to consume any errors. + if _, err = r.Read([]byte{}); err != nil { + r.Close() + return nil, err + } + return } func (r *syncFileReader) Read(buf []byte) (n int, err error) { @@ -35,6 +43,13 @@ func (r *syncFileReader) Read(buf []byte) (n int, err error) { r.chunkReader = chunkReader } + if len(buf) == 0 { + // Read can be called with an empty buffer to read the next chunk and check for errors. + // However, net.Conn.Read seems to return EOF when given an empty buffer, so we need to + // handle that case ourselves. + return 0, nil + } + n, err = r.chunkReader.Read(buf) if err == io.EOF { // End of current chunk, don't return an error, the next chunk will be @@ -53,17 +68,27 @@ func (r *syncFileReader) Close() error { // readNextChunk creates an io.LimitedReader for the next chunk of data, // and returns io.EOF if the last chunk has been read. func readNextChunk(r wire.SyncScanner) (io.Reader, error) { - id, err := r.ReadOctetString() + status, err := r.ReadStatus("read-chunk") if err != nil { + if wire.IsAdbServerErrorMatching(err, readFileNotFoundPredicate) { + return nil, util.Errorf(util.FileNoExistError, "no such file or directory") + } return nil, err } - switch id { - case "DATA": + switch status { + case wire.StatusSyncData: return r.ReadBytes() - case "DONE": + case wire.StatusSyncDone: return nil, io.EOF default: - return nil, fmt.Errorf("expected chunk id 'DATA', but got '%s'", id) + return nil, fmt.Errorf("expected chunk id '%s' or '%s', but got '%s'", + wire.StatusSyncData, wire.StatusSyncDone, []byte(status)) } } + +// readFileNotFoundPredicate returns true if s is the adb server error message returned +// when trying to open a file that doesn't exist. +func readFileNotFoundPredicate(s string) bool { + return s == "No such file or directory" +} diff --git a/sync_file_reader_test.go b/sync_file_reader_test.go index 67033ba..d2b2a55 100644 --- a/sync_file_reader_test.go +++ b/sync_file_reader_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) @@ -43,16 +44,17 @@ func TestReadNextChunkInvalidChunkId(t *testing.T) { // Read 1st chunk _, err := readNextChunk(s) - assert.EqualError(t, err, "expected chunk id 'DATA', but got 'ATAD'") + assert.EqualError(t, err, "expected chunk id 'DATA' or 'DONE', but got 'ATAD'") } func TestReadMultipleCalls(t *testing.T) { s := wire.NewSyncScanner(strings.NewReader( "DATA\006\000\000\000hello DATA\005\000\000\000worldDONE")) - reader := newSyncFileReader(s) + reader, err := newSyncFileReader(s) + assert.NoError(t, err) firstByte := make([]byte, 1) - _, err := io.ReadFull(reader, firstByte) + _, err = io.ReadFull(reader, firstByte) assert.NoError(t, err) assert.Equal(t, "h", string(firstByte)) @@ -73,10 +75,26 @@ func TestReadMultipleCalls(t *testing.T) { func TestReadAll(t *testing.T) { s := wire.NewSyncScanner(strings.NewReader( "DATA\006\000\000\000hello DATA\005\000\000\000worldDONE")) - reader := newSyncFileReader(s) + reader, err := newSyncFileReader(s) + assert.NoError(t, err) buf := make([]byte, 20) - _, err := io.ReadFull(reader, buf) + _, err = io.ReadFull(reader, buf) assert.Equal(t, io.ErrUnexpectedEOF, err) assert.Equal(t, "hello world\000", string(buf[:12])) } + +func TestReadError(t *testing.T) { + s := wire.NewSyncScanner(strings.NewReader( + "FAIL\004\000\000\000fail")) + _, err := newSyncFileReader(s) + assert.EqualError(t, err, "AdbError: server error for read-chunk request: fail ({Request:read-chunk ServerMsg:fail})") +} + +func TestReadErrorNotFound(t *testing.T) { + s := wire.NewSyncScanner(strings.NewReader( + "FAIL\031\000\000\000No such file or directory")) + _, err := newSyncFileReader(s) + assert.True(t, util.HasErrCode(err, util.FileNoExistError)) + assert.EqualError(t, err, "FileNoExistError: no such file or directory") +} diff --git a/wire/conn.go b/wire/conn.go index 21f781f..43a4db4 100644 --- a/wire/conn.go +++ b/wire/conn.go @@ -55,7 +55,7 @@ func (conn *Conn) RoundTripSingleResponse(req []byte) (resp []byte, err error) { return nil, err } - if err = ReadStatusFailureAsError(conn, string(req)); err != nil { + if _, err = conn.ReadStatus(string(req)); err != nil { return nil, err } diff --git a/wire/scanner.go b/wire/scanner.go index bf5d0c5..e6919cd 100644 --- a/wire/scanner.go +++ b/wire/scanner.go @@ -1,6 +1,7 @@ package wire import ( + "encoding/binary" "io" "io/ioutil" "strconv" @@ -12,16 +13,23 @@ import ( // StatusCodes are returned by the server. If the code indicates failure, the // next message will be the error. -type StatusCode string - const ( - StatusSuccess StatusCode = "OKAY" - StatusFailure = "FAIL" - StatusNone = "" + StatusSuccess string = "OKAY" + StatusFailure = "FAIL" + StatusSyncData = "DATA" + StatusSyncDone = "DONE" + StatusNone = "" ) -func (status StatusCode) IsSuccess() bool { - return status == StatusSuccess +func isFailureStatus(status string) bool { + return status == StatusFailure +} + +type StatusReader interface { + // Reads a 4-byte status string and returns it. + // If the status string is StatusFailure, reads the error message from the server + // and returns it as an util.AdbError. + ReadStatus(req string) (string, error) } /* @@ -29,13 +37,12 @@ Scanner reads tokens from a server. See Conn for more details. */ type Scanner interface { - ReadStatus() (StatusCode, error) + io.Closer + StatusReader ReadMessage() ([]byte, error) ReadUntilEof() ([]byte, error) NewSyncScanner() SyncScanner - - Close() error } type realScanner struct { @@ -54,36 +61,12 @@ func ReadMessageString(s Scanner) (string, error) { return string(msg), nil } -func (s *realScanner) ReadStatus() (StatusCode, error) { - status := make([]byte, 4) - n, err := io.ReadFull(s.reader, status) - - if err != nil && err != io.ErrUnexpectedEOF { - return "", util.WrapErrorf(err, util.NetworkError, "error reading status") - } else if err == io.ErrUnexpectedEOF { - return StatusCode(status), errIncompleteMessage("status", n, 4) - } - - return StatusCode(status), nil +func (s *realScanner) ReadStatus(req string) (string, error) { + return readStatusFailureAsError(s.reader, req, readHexLength) } func (s *realScanner) ReadMessage() ([]byte, error) { - var err error - - length, err := s.readLength() - if err != nil { - return nil, err - } - - data := make([]byte, length) - n, err := io.ReadFull(s.reader, data) - - if err != nil && err != io.ErrUnexpectedEOF { - return data, util.WrapErrorf(err, util.NetworkError, "error reading message data") - } else if err == io.ErrUnexpectedEOF { - return data, errIncompleteMessage("message data", n, length) - } - return data, nil + return readMessage(s.reader, readHexLength) } func (s *realScanner) ReadUntilEof() ([]byte, error) { @@ -102,9 +85,75 @@ func (s *realScanner) Close() error { return util.WrapErrorf(s.reader.Close(), util.NetworkError, "error closing scanner") } -func (s *realScanner) readLength() (int, error) { +var _ Scanner = &realScanner{} + +// lengthReader is a func that readMessage uses to read message length. +// See readHexLength and readInt32. +type lengthReader func(io.Reader) (int, error) + +// Reads the status, and if failure, reads the message and returns it as an error. +// If the status is success, doesn't read the message. +// req is just used to populate the AdbError, and can be nil. +// messageLengthReader is the function passed to readMessage if the status is failure. +func readStatusFailureAsError(r io.Reader, req string, messageLengthReader lengthReader) (string, error) { + status, err := readOctetString(req, r) + if err != nil { + return "", util.WrapErrorf(err, util.NetworkError, "error reading status for %s", req) + } + + if isFailureStatus(status) { + msg, err := readMessage(r, messageLengthReader) + if err != nil { + return "", util.WrapErrorf(err, util.NetworkError, + "server returned error for %s, but couldn't read the error message", req) + } + + return "", adbServerError(req, string(msg)) + } + + return status, nil +} + +func readOctetString(description string, r io.Reader) (string, error) { + octet := make([]byte, 4) + n, err := io.ReadFull(r, octet) + + if err == io.ErrUnexpectedEOF { + return "", errIncompleteMessage(description, n, 4) + } else if err != nil { + return "", util.WrapErrorf(err, util.NetworkError, "error reading "+description) + } + + return string(octet), nil +} + +// readMessage reads a length from r, then reads length bytes and returns them. +// lengthReader is the function used to read the length. Most operations encode +// length as a hex string (readHexLength), but sync operations use little-endian +// binary encoding (readInt32). +func readMessage(r io.Reader, lengthReader lengthReader) ([]byte, error) { + var err error + + length, err := lengthReader(r) + if err != nil { + return nil, err + } + + data := make([]byte, length) + n, err := io.ReadFull(r, data) + + if err != nil && err != io.ErrUnexpectedEOF { + return data, util.WrapErrorf(err, util.NetworkError, "error reading message data") + } else if err == io.ErrUnexpectedEOF { + return data, errIncompleteMessage("message data", n, length) + } + return data, nil +} + +// readHexLength reads the next 4 bytes from r as an ASCII hex-encoded length and parses them into an int. +func readHexLength(r io.Reader) (int, error) { lengthHex := make([]byte, 4) - n, err := io.ReadFull(s.reader, lengthHex) + n, err := io.ReadFull(r, lengthHex) if err != nil { return 0, errIncompleteMessage("length", n, 4) } @@ -122,4 +171,10 @@ func (s *realScanner) readLength() (int, error) { return int(length), nil } -var _ Scanner = &realScanner{} +// readInt32 reads the next 4 bytes from r as a little-endian integer. +// Returns an int instead of an int32 to match the lengthReader type. +func readInt32(r io.Reader) (int, error) { + var value int32 + err := binary.Read(r, binary.LittleEndian, &value) + return int(value), err +} diff --git a/wire/scanner_test.go b/wire/scanner_test.go index 5d53d4b..cde6499 100644 --- a/wire/scanner_test.go +++ b/wire/scanner_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "io" + "io/ioutil" "testing" "github.com/stretchr/testify/assert" @@ -11,109 +12,120 @@ import ( ) func TestReadStatusOkay(t *testing.T) { - s := NewScannerString("OKAYd") - status, err := s.ReadStatus() + s := newEofReader("OKAYd") + status, err := readStatusFailureAsError(s, "", readHexLength) assert.NoError(t, err) - assert.True(t, status.IsSuccess()) + assert.False(t, isFailureStatus(status)) assertNotEof(t, s) } func TestReadIncompleteStatus(t *testing.T) { - s := NewScannerString("oka") - _, err := s.ReadStatus() - assert.Equal(t, errIncompleteMessage("status", 3, 4), err) + s := newEofReader("oka") + _, err := readStatusFailureAsError(s, "", readHexLength) + assert.EqualError(t, err, "NetworkError: error reading status for ") + assert.Equal(t, errIncompleteMessage("", 3, 4), err.(*util.Err).Cause) + assertEof(t, s) +} + +func TestReadFailureIncompleteStatus(t *testing.T) { + s := newEofReader("FAIL") + _, err := readStatusFailureAsError(s, "req", readHexLength) + assert.EqualError(t, err, "NetworkError: server returned error for req, but couldn't read the error message") + assert.Error(t, err.(*util.Err).Cause) + assertEof(t, s) +} + +func TestReadFailureEmptyStatus(t *testing.T) { + s := newEofReader("FAIL0000") + _, err := readStatusFailureAsError(s, "", readHexLength) + assert.EqualError(t, err, "AdbError: server error: ({Request: ServerMsg:})") + assert.NoError(t, err.(*util.Err).Cause) + assertEof(t, s) +} + +func TestReadFailureStatus(t *testing.T) { + s := newEofReader("FAIL0004fail") + _, err := readStatusFailureAsError(s, "", readHexLength) + assert.EqualError(t, err, "AdbError: server error: fail ({Request: ServerMsg:fail})") + assert.NoError(t, err.(*util.Err).Cause) + assertEof(t, s) +} + +func TestReadMessage(t *testing.T) { + s := newEofReader("0005hello") + msg, err := readMessage(s, readHexLength) + assert.NoError(t, err) + assert.Len(t, msg, 5) + assert.Equal(t, "hello", string(msg)) + assertEof(t, s) +} + +func TestReadMessageWithExtraData(t *testing.T) { + s := newEofReader("0005hellothere") + msg, err := readMessage(s, readHexLength) + assert.NoError(t, err) + assert.Len(t, msg, 5) + assert.Equal(t, "hello", string(msg)) + assertNotEof(t, s) +} + +func TestReadLongerMessage(t *testing.T) { + s := newEofReader("001b192.168.56.101:5555 device\n") + msg, err := readMessage(s, readHexLength) + assert.NoError(t, err) + assert.Len(t, msg, 27) + assert.Equal(t, "192.168.56.101:5555 device\n", string(msg)) + assertEof(t, s) +} + +func TestReadEmptyMessage(t *testing.T) { + s := newEofReader("0000") + msg, err := readMessage(s, readHexLength) + assert.NoError(t, err) + assert.Equal(t, "", string(msg)) + assertEof(t, s) +} + +func TestReadIncompleteMessage(t *testing.T) { + s := newEofReader("0005hel") + msg, err := readMessage(s, readHexLength) + assert.Error(t, err) + assert.Equal(t, errIncompleteMessage("message data", 3, 5), err) + assert.Equal(t, "hel\000\000", string(msg)) assertEof(t, s) } func TestReadLength(t *testing.T) { - s := NewScannerString("000a") - l, err := s.readLength() + s := newEofReader("000a") + l, err := readHexLength(s) assert.NoError(t, err) assert.Equal(t, 10, l) assertEof(t, s) } -func TestReadIncompleteLength(t *testing.T) { - s := NewScannerString("aaa") - _, err := s.readLength() +func TestReadLengthIncompleteLength(t *testing.T) { + s := newEofReader("aaa") + _, err := readHexLength(s) assert.Equal(t, errIncompleteMessage("length", 3, 4), err) assertEof(t, s) } -func TestReadMessage(t *testing.T) { - s := NewScannerString("0005hello") - msg, err := ReadMessageString(s) - assert.NoError(t, err) - assert.Len(t, msg, 5) - assert.Equal(t, "hello", msg) - assertEof(t, s) -} - -func TestReadMessageWithExtraData(t *testing.T) { - s := NewScannerString("0005hellothere") - msg, err := ReadMessageString(s) - assert.NoError(t, err) - assert.Len(t, msg, 5) - assert.Equal(t, "hello", msg) - assertNotEof(t, s) -} - -func TestReadLongerMessage(t *testing.T) { - s := NewScannerString("001b192.168.56.101:5555 device\n") - msg, err := ReadMessageString(s) - assert.NoError(t, err) - assert.Len(t, msg, 27) - assert.Equal(t, "192.168.56.101:5555 device\n", msg) - assertEof(t, s) -} - -func TestReadEmptyMessage(t *testing.T) { - s := NewScannerString("0000") - msg, err := ReadMessageString(s) - assert.NoError(t, err) - assert.Equal(t, "", msg) - assertEof(t, s) -} - -func TestReadIncompleteMessage(t *testing.T) { - s := NewScannerString("0005hel") - msg, err := ReadMessageString(s) - assert.Error(t, err) - assert.Equal(t, errIncompleteMessage("message data", 3, 5), err) - assert.Equal(t, "hel\000\000", msg) - assertEof(t, s) -} - -func NewScannerString(str string) *realScanner { - return NewScanner(NewEofBuffer(str)).(*realScanner) -} - -// NewEofBuffer returns a bytes.Buffer of str that returns an EOF error -// at the end of input, instead of just returning 0 bytes read. -func NewEofBuffer(str string) *TestReader { - limitReader := io.LimitReader(bytes.NewBufferString(str), int64(len(str))) - bufReader := bufio.NewReader(limitReader) - return &TestReader{bufReader} -} - -func assertEof(t *testing.T, s *realScanner) { - msg, err := s.ReadMessage() +func assertEof(t *testing.T, r io.Reader) { + msg, err := readMessage(r, readHexLength) assert.True(t, util.HasErrCode(err, util.ConnectionResetError)) assert.Nil(t, msg) } -func assertNotEof(t *testing.T, s *realScanner) { - n, err := s.reader.Read(make([]byte, 1)) +func assertNotEof(t *testing.T, r io.Reader) { + n, err := r.Read(make([]byte, 1)) assert.Equal(t, 1, n) assert.NoError(t, err) } -// TestReader is a wrapper around a bufio.Reader that implements io.Closer. -type TestReader struct { - *bufio.Reader -} - -func (b *TestReader) Close() error { - // No-op. - return nil +// newEofBuffer returns a bytes.Buffer of str that returns an EOF error +// at the end of input, instead of just returning 0 bytes read. +func newEofReader(str string) io.ReadCloser { + limitReader := io.LimitReader(bytes.NewBufferString(str), int64(len(str))) + bufReader := bufio.NewReader(limitReader) + return ioutil.NopCloser(bufReader) } diff --git a/wire/sync_scanner.go b/wire/sync_scanner.go index 803d65a..dce14c8 100644 --- a/wire/sync_scanner.go +++ b/wire/sync_scanner.go @@ -10,8 +10,8 @@ import ( ) type SyncScanner interface { - // ReadOctetString reads a 4-byte string. - ReadOctetString() (string, error) + io.Closer + StatusReader ReadInt32() (int32, error) ReadFileMode() (os.FileMode, error) ReadTime() (time.Time, error) @@ -23,9 +23,6 @@ type SyncScanner interface { // bytes (see io.LimitReader). The returned reader should be fully // read before reading anything off the Scanner again. ReadBytes() (io.Reader, error) - - // Closes the underlying reader. - Close() error } type realSyncScanner struct { @@ -36,33 +33,13 @@ func NewSyncScanner(r io.Reader) SyncScanner { return &realSyncScanner{r} } -func RequireOctetString(s SyncScanner, expected string) error { - actual, err := s.ReadOctetString() - if err != nil { - return util.WrapErrorf(err, util.NetworkError, "expected to read '%s'", expected) - } - if actual != expected { - return util.AssertionErrorf("expected to read '%s', got '%s'", expected, actual) - } - return nil +func (s *realSyncScanner) ReadStatus(req string) (string, error) { + return readStatusFailureAsError(s.Reader, req, readInt32) } -func (s *realSyncScanner) ReadOctetString() (string, error) { - octet := make([]byte, 4) - n, err := io.ReadFull(s.Reader, octet) - - if err != nil && err != io.ErrUnexpectedEOF { - return "", util.WrapErrorf(err, util.NetworkError, "error reading octet string from sync scanner") - } else if err == io.ErrUnexpectedEOF { - return "", errIncompleteMessage("octet", n, 4) - } - - return string(octet), nil -} func (s *realSyncScanner) ReadInt32() (int32, error) { - var value int32 - err := binary.Read(s.Reader, binary.LittleEndian, &value) - return value, util.WrapErrorf(err, util.NetworkError, "error reading int from sync scanner") + value, err := readInt32(s.Reader) + return int32(value), util.WrapErrorf(err, util.NetworkError, "error reading int from sync scanner") } func (s *realSyncScanner) ReadFileMode() (os.FileMode, error) { var value uint32 diff --git a/wire/sync_test.go b/wire/sync_test.go index 1966d55..2c928fb 100644 --- a/wire/sync_test.go +++ b/wire/sync_test.go @@ -17,13 +17,6 @@ var ( someTimeEncoded = []byte{151, 208, 42, 85} ) -func TestSyncReadOctetString(t *testing.T) { - s := NewSyncScanner(strings.NewReader("helo")) - str, err := s.ReadOctetString() - assert.NoError(t, err) - assert.Equal(t, "helo", str) -} - func TestSyncSendOctetString(t *testing.T) { var buf bytes.Buffer s := NewSyncSender(&buf) diff --git a/wire/util.go b/wire/util.go index b65a881..0a0c5e1 100644 --- a/wire/util.go +++ b/wire/util.go @@ -21,28 +21,6 @@ type ErrorResponseDetails struct { // Old servers send "device not found", and newer ones "device 'serial' not found". var deviceNotFoundMessagePattern = regexp.MustCompile(`device( '.*')? not found`) -// Reads the status, and if failure, reads the message and returns it as an error. -// If the status is success, doesn't read the message. -// req is just used to populate the AdbError, and can be nil. -func ReadStatusFailureAsError(s Scanner, req string) error { - status, err := s.ReadStatus() - if err != nil { - return util.WrapErrorf(err, util.NetworkError, "error reading status for %s", req) - } - - if !status.IsSuccess() { - msg, err := s.ReadMessage() - if err != nil { - return util.WrapErrorf(err, util.NetworkError, - "server returned error for %s, but couldn't read the error message", req) - } - - return adbServerError(req, string(msg)) - } - - return nil -} - func adbServerError(request string, serverMsg string) error { var msg string if request == "" { @@ -66,6 +44,15 @@ func adbServerError(request string, serverMsg string) error { } } +// IsAdbServerErrorMatching returns true if err is an *util.Err with code AdbError and for which +// predicate returns true when passed Details.ServerMsg. +func IsAdbServerErrorMatching(err error, predicate func(string) bool) bool { + if err, ok := err.(*util.Err); ok && err.Code == util.AdbError { + return predicate(err.Details.(ErrorResponseDetails).ServerMsg) + } + return false +} + func errIncompleteMessage(description string, actual int, expected int) error { return &util.Err{ Code: util.ConnectionResetError,