diff --git a/dialer.go b/dialer.go index ebb5c12..83334a6 100644 --- a/dialer.go +++ b/dialer.go @@ -2,6 +2,7 @@ package goadb import ( "fmt" + "io" "net" "runtime" @@ -66,17 +67,22 @@ func (d *netDialer) Dial() (*wire.Conn, error) { } } - conn := &wire.Conn{ - Scanner: wire.NewScanner(netConn), - Sender: wire.NewSender(netConn), - } + // net.Conn can't be closed more than once, but wire.Conn will try to close both sender and scanner + // so we need to wrap it to make it safe. + safeConn := wire.MultiCloseable(netConn) // Prevent leaking the network connection, not sure if TCPConn does this itself. - runtime.SetFinalizer(netConn, func(conn *net.TCPConn) { + // Note that the network connection may still be in use after the conn isn't (scanners/senders + // can give their underlying connections to other scanner/sender types), so we can't + // set the finalizer on conn. + runtime.SetFinalizer(safeConn, func(conn io.ReadWriteCloser) { conn.Close() }) - return conn, nil + return &wire.Conn{ + Scanner: wire.NewScanner(safeConn), + Sender: wire.NewSender(safeConn), + }, nil } func roundTripSingleResponse(d Dialer, req string) ([]byte, error) { diff --git a/sync_file_reader.go b/sync_file_reader.go index 69f474d..1aa1b89 100644 --- a/sync_file_reader.go +++ b/sync_file_reader.go @@ -1,7 +1,6 @@ package goadb import ( - "fmt" "io" "github.com/zach-klippenstein/goadb/util" @@ -15,6 +14,9 @@ type syncFileReader struct { // Reader for the current chunk only. chunkReader io.Reader + + // False until the DONE chunk is encountered. + eof bool } var _ io.ReadCloser = &syncFileReader{} @@ -26,18 +28,30 @@ func newSyncFileReader(s wire.SyncScanner) (r io.ReadCloser, err error) { // Read the header for the first chunk to consume any errors. if _, err = r.Read([]byte{}); err != nil { - r.Close() - return nil, err + if err == io.EOF { + // EOF means the file was empty. This still means the file was opened successfully, + // and the next time the caller does a read they'll get the EOF and handle it themselves. + err = nil + } else { + r.Close() + return nil, err + } } return } func (r *syncFileReader) Read(buf []byte) (n int, err error) { + if r.eof { + return 0, io.EOF + } + if r.chunkReader == nil { chunkReader, err := readNextChunk(r.scanner) if err != nil { - // If this is EOF, we've read the last chunk. - // Either way, we want to pass it up to the caller. + if err == io.EOF { + // We just read the last chunk, set our flag before passing it up. + r.eof = true + } return 0, err } r.chunkReader = chunkReader @@ -82,7 +96,7 @@ func readNextChunk(r wire.SyncScanner) (io.Reader, error) { case wire.StatusSyncDone: return nil, io.EOF default: - return nil, fmt.Errorf("expected chunk id '%s' or '%s', but got '%s'", + return nil, util.Errorf(util.AssertionError, "expected chunk id '%s' or '%s', but got '%s'", wire.StatusSyncData, wire.StatusSyncDone, []byte(status)) } } diff --git a/sync_file_reader_test.go b/sync_file_reader_test.go index d2b2a55..83da0e8 100644 --- a/sync_file_reader_test.go +++ b/sync_file_reader_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" + "io/ioutil" ) func TestReadNextChunk(t *testing.T) { @@ -44,7 +45,7 @@ func TestReadNextChunkInvalidChunkId(t *testing.T) { // Read 1st chunk _, err := readNextChunk(s) - assert.EqualError(t, err, "expected chunk id 'DATA' or 'DONE', but got 'ATAD'") + assert.EqualError(t, err, "AssertionError: expected chunk id 'DATA' or 'DONE', but got 'ATAD'") } func TestReadMultipleCalls(t *testing.T) { @@ -91,6 +92,20 @@ func TestReadError(t *testing.T) { assert.EqualError(t, err, "AdbError: server error for read-chunk request: fail ({Request:read-chunk ServerMsg:fail})") } +func TestReadEmpty(t *testing.T) { + s := wire.NewSyncScanner(strings.NewReader( + "DONE")) + r, err := newSyncFileReader(s) + assert.NoError(t, err) + + // Multiple read calls that return EOF is a valid case. + for i := 0; i < 5; i++ { + data, err := ioutil.ReadAll(r) + assert.NoError(t, err) + assert.Empty(t, data) + } +} + func TestReadErrorNotFound(t *testing.T) { s := wire.NewSyncScanner(strings.NewReader( "FAIL\031\000\000\000No such file or directory")) diff --git a/sync_file_writer.go b/sync_file_writer.go index 35130df..80c24a6 100644 --- a/sync_file_writer.go +++ b/sync_file_writer.go @@ -43,21 +43,31 @@ func encodePathAndMode(path string, mode os.FileMode) []byte { // Write writes the min of (len(buf), 64k). func (w *syncFileWriter) Write(buf []byte) (n int, err error) { - // Writes < 64k have a one-to-one mapping to chunks. - // If buffer is larger than the max, we'll return the max size and leave it up to the - // caller to handle correctly. - if len(buf) > wire.SyncMaxChunkSize { - buf = buf[:wire.SyncMaxChunkSize] + written := 0 + + // If buf > 64k we'll have to send multiple chunks. + // TODO Refactor this into something that can coalesce smaller writes into a single chukn. + for len(buf) > 0 { + // Writes < 64k have a one-to-one mapping to chunks. + // If buffer is larger than the max, we'll return the max size and leave it up to the + // caller to handle correctly. + partialBuf := buf + if len(partialBuf) > wire.SyncMaxChunkSize { + partialBuf = partialBuf[:wire.SyncMaxChunkSize] + } + + if err := w.sender.SendOctetString(wire.StatusSyncData); err != nil { + return written, err + } + if err := w.sender.SendBytes(partialBuf); err != nil { + return written, err + } + + written += len(partialBuf) + buf = buf[len(partialBuf):] } - if err := w.sender.SendOctetString(wire.StatusSyncData); err != nil { - return 0, err - } - if err := w.sender.SendBytes(buf); err != nil { - return 0, err - } - - return len(buf), nil + return written, nil } func (w *syncFileWriter) Close() error { @@ -66,7 +76,7 @@ func (w *syncFileWriter) Close() error { } if err := w.sender.SendOctetString(wire.StatusSyncDone); err != nil { - return util.WrapErrf(err, "error closing file writer") + return util.WrapErrf(err, "error sending done chunk to close stream") } if err := w.sender.SendTime(w.mtime); err != nil { return util.WrapErrf(err, "error writing file modification time") diff --git a/sync_file_writer_test.go b/sync_file_writer_test.go index 1317ef1..3b10155 100644 --- a/sync_file_writer_test.go +++ b/sync_file_writer_test.go @@ -42,18 +42,26 @@ func TestFileWriterWriteLargeChunk(t *testing.T) { var buf bytes.Buffer writer := newSyncFileWriter(wire.NewSyncSender(&buf), MtimeOfClose) + // Send just enough data to get 2 chunks. data := make([]byte, wire.SyncMaxChunkSize+1) n, err := writer.Write(data) assert.NoError(t, err) - assert.Equal(t, wire.SyncMaxChunkSize, n) - assert.Equal(t, 8 + wire.SyncMaxChunkSize, buf.Len()) + assert.Equal(t, wire.SyncMaxChunkSize+1, n) + assert.Equal(t, 8 + 8 + wire.SyncMaxChunkSize+1, buf.Len()) - expectedHeader := []byte("DATA0000") + // First header. + chunk := buf.Bytes()[:8+wire.SyncMaxChunkSize] + expectedHeader := []byte("DATA----") binary.LittleEndian.PutUint32(expectedHeader[4:], wire.SyncMaxChunkSize) - assert.Equal(t, expectedHeader, buf.Bytes()[:8]) + assert.Equal(t, expectedHeader, chunk[:8]) + assert.Equal(t, data[:wire.SyncMaxChunkSize], chunk[8:]) - assert.Equal(t, string(data[:wire.SyncMaxChunkSize]), buf.String()[8:]) + // Second header. + chunk = buf.Bytes()[wire.SyncMaxChunkSize+8:wire.SyncMaxChunkSize+8+1] + expectedHeader = []byte("DATA\000\000\000\000") + binary.LittleEndian.PutUint32(expectedHeader[4:], 1) + assert.Equal(t, expectedHeader, chunk[:8]) } func TestFileWriterCloseEmpty(t *testing.T) { diff --git a/util.go b/util.go index 6291f23..92d83ef 100644 --- a/util.go +++ b/util.go @@ -26,7 +26,7 @@ func wrapClientError(err error, client interface{}, operation string, args ...in return nil } if _, ok := err.(*util.Err); !ok { - panic("err is not a *util.Err") + panic("err is not a *util.Err: " + err.Error()) } clientType := reflect.TypeOf(client) diff --git a/util/error.go b/util/error.go index 9246fe6..684ad50 100644 --- a/util/error.go +++ b/util/error.go @@ -176,7 +176,12 @@ func ErrorWithCauseChain(err error) string { break } } - buffer.WriteString(err.Error()) + + if err != nil { + buffer.WriteString(err.Error()) + } else { + buffer.WriteString("") + } return buffer.String() } diff --git a/util/error_test.go b/util/error_test.go index bc902d6..375515a 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -23,6 +23,8 @@ caused by AssertionError: err2 caused by err3` assert.Equal(t, expected, ErrorWithCauseChain(err)) + + assert.Equal(t, "", ErrorWithCauseChain(nil)) } func TestCombineErrors(t *testing.T) { diff --git a/wire/sync_sender.go b/wire/sync_sender.go index 9189080..182678a 100644 --- a/wire/sync_sender.go +++ b/wire/sync_sender.go @@ -68,8 +68,7 @@ func (s *realSyncSender) SendBytes(data []byte) error { if err := s.SendInt32(int32(length)); err != nil { return util.WrapErrorf(err, util.NetworkError, "error sending data length on sync sender") } - return util.WrapErrorf(writeFully(s.Writer, data), - util.NetworkError, "error sending data on sync sender") + return writeFully(s.Writer, data) } func (s *realSyncSender) Close() error { diff --git a/wire/util.go b/wire/util.go index 0a0c5e1..63bb7fc 100644 --- a/wire/util.go +++ b/wire/util.go @@ -5,6 +5,8 @@ import ( "io" "regexp" + "sync" + "github.com/zach-klippenstein/goadb/util" ) @@ -80,3 +82,21 @@ func writeFully(w io.Writer, data []byte) error { } return nil } + +// MultiCloseable wraps c in a ReadWriteCloser that can be safely closed multiple times. +func MultiCloseable(c io.ReadWriteCloser) io.ReadWriteCloser { + return &multiCloseable{ReadWriteCloser: c} +} + +type multiCloseable struct { + io.ReadWriteCloser + closeOnce sync.Once + err error +} + +func (c *multiCloseable) Close() error { + c.closeOnce.Do(func() { + c.err = c.ReadWriteCloser.Close() + }) + return c.err +}