diff --git a/cmd/adb/main.go b/cmd/adb/main.go index 2dfddf2..48f7655 100644 --- a/cmd/adb/main.go +++ b/cmd/adb/main.go @@ -27,9 +27,9 @@ var ( pullRemoteArg = pullCommand.Arg("remote", "Path of source file on device.").Required().String() pullLocalArg = pullCommand.Arg("local", "Path of destination file.").String() - pushCommand = kingpin.Command("push", "Push a file to the device.").Hidden() + pushCommand = kingpin.Command("push", "Push a file to the device.") pushProgressFlag = pushCommand.Flag("progress", "Show progress.").Short('p').Bool() - pushLocalArg = pushCommand.Arg("local", "Path of source file.").Required().String() + pushLocalArg = pushCommand.Arg("local", "Path of source file.").Required().File() pushRemoteArg = pushCommand.Arg("remote", "Path of destination file on device.").Required().String() ) @@ -143,47 +143,70 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe } defer localFile.Close() - var output io.Writer - var updateProgress func() + if err := copyWithProgressAndStats(localFile, remoteFile, int(info.Size), showProgress); err != nil { + fmt.Fprintln(os.Stderr, "error pulling file:", err) + return 1 + } + return 0 +} + +func push(showProgress bool, localFile *os.File, remotePath string, device goadb.DeviceDescriptor) int { + if remotePath == "" { + fmt.Fprintln(os.Stderr, "error: must specify remote file") + kingpin.Usage() + return 1 + } + + client := goadb.NewDeviceClient(goadb.ClientConfig{}, device) + + info, err := os.Stat(localFile.Name()) + if err != nil { + fmt.Fprintf(os.Stderr, "error reading local file %s: %s\n", localFile.Name(), err) + return 1 + } + + writer, err := client.OpenWrite(remotePath, info.Mode(), info.ModTime()) + if err != nil { + fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err) + return 1 + } + defer writer.Close() + + if err := copyWithProgressAndStats(writer, localFile, int(info.Size()), showProgress); err != nil { + fmt.Fprintln(os.Stderr, "error pushing file:", err) + return 1 + } + return 0 +} + +func copyWithProgressAndStats(dst io.Writer, src io.Reader, size int, showProgress bool) error { + var progress *pb.ProgressBar if showProgress { - output, updateProgress = createProgressBarWriter(int(info.Size), localFile) - } else { - output = localFile - updateProgress = func() {} + progress = pb.New(size) + progress.SetUnits(pb.U_BYTES) + progress.SetRefreshRate(100 * time.Millisecond) + progress.ShowSpeed = true + progress.ShowPercent = true + progress.ShowTimeLeft = true + progress.Start() + dst = io.MultiWriter(dst, progress) } startTime := time.Now() - copied, err := io.Copy(output, remoteFile) + copied, err := io.Copy(dst, src) - // Force progress update if the transfer was really fast. - updateProgress() + if progress != nil { + // Force progress update if the transfer was really fast. + progress.Update() + } if err != nil { - fmt.Fprintln(os.Stderr, "error pulling file:", err) - return 1 + return err } duration := time.Now().Sub(startTime) rate := int64(float64(copied) / duration.Seconds()) fmt.Printf("%d B/s (%d bytes in %s)\n", rate, copied, duration) - return 0 -} -func push(showProgress bool, localPath, remotePath string, device goadb.DeviceDescriptor) int { - fmt.Fprintln(os.Stderr, "not implemented") - return 1 -} - -func createProgressBarWriter(size int, w io.Writer) (progressWriter io.Writer, update func()) { - progress := pb.New(size) - progress.SetUnits(pb.U_BYTES) - progress.SetRefreshRate(100 * time.Millisecond) - progress.ShowSpeed = true - progress.ShowPercent = true - progress.ShowTimeLeft = true - progress.Start() - - progressWriter = io.MultiWriter(w, progress) - update = progress.Update - return + return nil } diff --git a/device_client.go b/device_client.go index d69102f..ed2598c 100644 --- a/device_client.go +++ b/device_client.go @@ -3,12 +3,18 @@ package goadb import ( "fmt" "io" + "os" "strings" + "time" "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) +// MtimeOfClose should be passed to OpenWrite to set the file modification time to the time the Close +// method is called. +var MtimeOfClose = time.Time{} + // DeviceClient communicates with a specific Android device. type DeviceClient struct { config ClientConfig @@ -171,6 +177,20 @@ func (c *DeviceClient) OpenRead(path string) (io.ReadCloser, error) { return reader, wrapClientError(err, c, "OpenRead(%s)", path) } +// OpenWrite opens the file at path on the device, creating it with the permissions specified +// by perms if necessary, and returns a writer that writes to the file. +// The files modification time will be set to mtime when the WriterCloser is closed. The zero value +// is TimeOfClose, which will use the time the Close method is called as the modification time. +func (c *DeviceClient) OpenWrite(path string, perms os.FileMode, mtime time.Time) (io.WriteCloser, error) { + conn, err := c.getSyncConn() + if err != nil { + return nil, wrapClientError(err, c, "OpenWrite(%s)", path) + } + + writer, err := sendFile(conn, path, perms, mtime) + return writer, wrapClientError(err, c, "OpenWrite(%s)", path) +} + // getAttribute returns the first message returned by the server by running // :, where host-prefix is determined from the DeviceDescriptor. func (c *DeviceClient) getAttribute(attr string) (string, error) { diff --git a/sync_client.go b/sync_client.go index c3497d6..75dfdd4 100644 --- a/sync_client.go +++ b/sync_client.go @@ -1,4 +1,3 @@ -// TODO(z): Implement send. package goadb import ( @@ -16,7 +15,7 @@ func stat(conn *wire.SyncConn, path string) (*DirEntry, error) { if err := conn.SendOctetString("STAT"); err != nil { return nil, err } - if err := conn.SendString(path); err != nil { + if err := conn.SendBytes([]byte(path)); err != nil { return nil, err } @@ -35,7 +34,7 @@ func listDirEntries(conn *wire.SyncConn, path string) (entries *DirEntries, err if err = conn.SendOctetString("LIST"); err != nil { return } - if err = conn.SendString(path); err != nil { + if err = conn.SendBytes([]byte(path)); err != nil { return } @@ -46,12 +45,29 @@ func receiveFile(conn *wire.SyncConn, path string) (io.ReadCloser, error) { if err := conn.SendOctetString("RECV"); err != nil { return nil, err } - if err := conn.SendString(path); err != nil { + if err := conn.SendBytes([]byte(path)); err != nil { return nil, err } return newSyncFileReader(conn) } +// sendFile returns a WriteCloser than will write to the file at path on device. +// The file will be created with permissions specified by mode. +// The file's modified time will be set to mtime, unless mtime is 0, in which case the time the writer is +// closed will be used. +func sendFile(conn *wire.SyncConn, path string, mode os.FileMode, mtime time.Time) (io.WriteCloser, error) { + if err := conn.SendOctetString("SEND"); err != nil { + return nil, err + } + + pathAndMode := encodePathAndMode(path, mode) + if err := conn.SendBytes(pathAndMode); err != nil { + return nil, err + } + + return newSyncFileWriter(conn, mtime), nil +} + func readStat(s wire.SyncScanner) (entry *DirEntry, err error) { mode, err := s.ReadFileMode() if err != nil { diff --git a/sync_file_writer.go b/sync_file_writer.go new file mode 100644 index 0000000..35130df --- /dev/null +++ b/sync_file_writer.go @@ -0,0 +1,76 @@ +package goadb + +import ( + "fmt" + "io" + "os" + "time" + + "github.com/zach-klippenstein/goadb/util" + "github.com/zach-klippenstein/goadb/wire" +) + +// syncFileWriter wraps a SyncConn that has requested to send a file. +type syncFileWriter struct { + // The modification time to write in the footer. + // If 0, use the current time. + mtime time.Time + + // Reader used to read data from the adb connection. + sender wire.SyncSender +} + +var _ io.WriteCloser = &syncFileWriter{} + +func newSyncFileWriter(s wire.SyncSender, mtime time.Time) io.WriteCloser { + return &syncFileWriter{ + mtime: mtime, + sender: s, + } +} + +/* +encodePathAndMode encodes a path and file mode as required for starting a send file stream. + +From https://android.googlesource.com/platform/system/core/+/master/adb/SYNC.TXT: + The remote file name is split into two parts separated by the last + comma (","). The first part is the actual path, while the second is a decimal + encoded file mode containing the permissions of the file on device. +*/ +func encodePathAndMode(path string, mode os.FileMode) []byte { + return []byte(fmt.Sprintf("%s,%d", path, uint32(mode.Perm()))) +} + +// Write writes the min of (len(buf), 64k). +func (w *syncFileWriter) Write(buf []byte) (n int, err error) { + // Writes < 64k have a one-to-one mapping to chunks. + // If buffer is larger than the max, we'll return the max size and leave it up to the + // caller to handle correctly. + if len(buf) > wire.SyncMaxChunkSize { + buf = buf[:wire.SyncMaxChunkSize] + } + + if err := w.sender.SendOctetString(wire.StatusSyncData); err != nil { + return 0, err + } + if err := w.sender.SendBytes(buf); err != nil { + return 0, err + } + + return len(buf), nil +} + +func (w *syncFileWriter) Close() error { + if w.mtime.IsZero() { + w.mtime = time.Now() + } + + if err := w.sender.SendOctetString(wire.StatusSyncDone); err != nil { + return util.WrapErrf(err, "error closing file writer") + } + if err := w.sender.SendTime(w.mtime); err != nil { + return util.WrapErrf(err, "error writing file modification time") + } + + return util.WrapErrf(w.sender.Close(), "error closing FileWriter") +} diff --git a/sync_file_writer_test.go b/sync_file_writer_test.go new file mode 100644 index 0000000..1317ef1 --- /dev/null +++ b/sync_file_writer_test.go @@ -0,0 +1,93 @@ +package goadb + +import ( + "bytes" + "testing" + "time" + + "encoding/binary" + "strings" + + "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/wire" +) + +func TestFileWriterWriteSingleChunk(t *testing.T) { + var buf bytes.Buffer + writer := newSyncFileWriter(wire.NewSyncSender(&buf), MtimeOfClose) + + n, err := writer.Write([]byte("hello")) + assert.NoError(t, err) + assert.Equal(t, 5, n) + + assert.Equal(t, "DATA\005\000\000\000hello", buf.String()) +} + +func TestFileWriterWriteMultiChunk(t *testing.T) { + var buf bytes.Buffer + writer := newSyncFileWriter(wire.NewSyncSender(&buf), MtimeOfClose) + + n, err := writer.Write([]byte("hello")) + assert.NoError(t, err) + assert.Equal(t, 5, n) + + n, err = writer.Write([]byte(" world")) + assert.NoError(t, err) + assert.Equal(t, 6, n) + + assert.Equal(t, "DATA\005\000\000\000helloDATA\006\000\000\000 world", buf.String()) +} + +func TestFileWriterWriteLargeChunk(t *testing.T) { + var buf bytes.Buffer + writer := newSyncFileWriter(wire.NewSyncSender(&buf), MtimeOfClose) + + data := make([]byte, wire.SyncMaxChunkSize+1) + n, err := writer.Write(data) + + assert.NoError(t, err) + assert.Equal(t, wire.SyncMaxChunkSize, n) + assert.Equal(t, 8 + wire.SyncMaxChunkSize, buf.Len()) + + expectedHeader := []byte("DATA0000") + binary.LittleEndian.PutUint32(expectedHeader[4:], wire.SyncMaxChunkSize) + assert.Equal(t, expectedHeader, buf.Bytes()[:8]) + + assert.Equal(t, string(data[:wire.SyncMaxChunkSize]), buf.String()[8:]) +} + +func TestFileWriterCloseEmpty(t *testing.T) { + var buf bytes.Buffer + mtime := time.Unix(1, 0) + writer := newSyncFileWriter(wire.NewSyncSender(&buf), mtime) + + assert.NoError(t, writer.Close()) + + assert.Equal(t, "DONE\x01\x00\x00\x00", buf.String()) +} + +func TestFileWriterWriteClose(t *testing.T) { + var buf bytes.Buffer + mtime := time.Unix(1, 0) + writer := newSyncFileWriter(wire.NewSyncSender(&buf), mtime) + + writer.Write([]byte("hello")) + assert.NoError(t, writer.Close()) + + assert.Equal(t, "DATA\005\000\000\000helloDONE\x01\x00\x00\x00", buf.String()) +} + +func TestFileWriterCloseAutoMtime(t *testing.T) { + var buf bytes.Buffer + writer := newSyncFileWriter(wire.NewSyncSender(&buf), MtimeOfClose) + + assert.NoError(t, writer.Close()) + assert.Len(t, buf.String(), 8) + assert.True(t, strings.HasPrefix(buf.String(), "DONE")) + + mtimeBytes := buf.Bytes()[4:] + mtimeActual := time.Unix(int64(binary.LittleEndian.Uint32(mtimeBytes)), 0) + + // Delta has to be a whole second since adb only supports second granularity for mtimes. + assert.WithinDuration(t, time.Now(), mtimeActual, 1*time.Second) +} diff --git a/util/error.go b/util/error.go index e065899..9246fe6 100644 --- a/util/error.go +++ b/util/error.go @@ -77,6 +77,44 @@ func WrapErrf(cause error, format string, args ...interface{}) error { } } +// CombineErrs returns an error that wraps all the non-nil errors passed to it. +// If all errors are nil, returns nil. +// If there's only one non-nil error, returns that error without wrapping. +// Else, returns an error with the message and code as passed, with the cause set to an error +// that contains all the non-nil errors and for which Error() returns the concatenation of all their messages. +func CombineErrs(msg string, code ErrCode, errs ...error) error { + var nonNilErrs []error + for _, err := range errs { + if err != nil { + nonNilErrs = append(nonNilErrs, err) + } + } + + switch len(nonNilErrs) { + case 0: + return nil + case 1: + return nonNilErrs[0] + default: + return WrapErrorf(multiError(nonNilErrs), code, "%s", msg) + } +} + +type multiError []error + +func (errs multiError) Error() string { + var buf bytes.Buffer + fmt.Fprintf(&buf, "%d errors: [", len(errs)) + for i, err := range errs { + buf.WriteString(err.Error()) + if i < len(errs)-1 { + buf.WriteString(" ∪ ") + } + } + buf.WriteRune(']') + return buf.String() +} + /* WrapErrorf returns an *Err that wraps another arbitrary error with an ErrCode and a message. diff --git a/util/error_test.go b/util/error_test.go index 2e629db..bc902d6 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -24,3 +24,19 @@ caused by err3` assert.Equal(t, expected, ErrorWithCauseChain(err)) } + +func TestCombineErrors(t *testing.T) { + assert.NoError(t, CombineErrs("hello", AdbError)) + assert.NoError(t, CombineErrs("hello", AdbError, nil, nil)) + + err1 := errors.New("lulz") + err2 := errors.New("fail") + + err := CombineErrs("hello", AdbError, nil, err1, nil) + assert.EqualError(t, err, "lulz") + + err = CombineErrs("hello", AdbError, err1, err2) + assert.EqualError(t, err, "AdbError: hello") + assert.Equal(t, `AdbError: hello +caused by 2 errors: [lulz ∪ fail]`, ErrorWithCauseChain(err)) +} diff --git a/wire/sync_conn.go b/wire/sync_conn.go index 521efc0..1752a09 100644 --- a/wire/sync_conn.go +++ b/wire/sync_conn.go @@ -1,9 +1,10 @@ -// TODO(z): Write SyncSender.SendBytes(). package wire +import "github.com/zach-klippenstein/goadb/util" + const ( // Chunks cannot be longer than 64k. - MaxChunkSize = 64 * 1024 + SyncMaxChunkSize = 64 * 1024 ) /* @@ -28,3 +29,9 @@ type SyncConn struct { SyncScanner SyncSender } + +// Close closes both the sender and the scanner, and returns any errors. +func (c SyncConn) Close() error { + return util.CombineErrs("error closing SyncConn", util.NetworkError, + c.SyncScanner.Close(), c.SyncSender.Close()) +} diff --git a/wire/sync_sender.go b/wire/sync_sender.go index 15dfde8..9189080 100644 --- a/wire/sync_sender.go +++ b/wire/sync_sender.go @@ -10,14 +10,17 @@ import ( ) type SyncSender interface { + io.Closer + // SendOctetString sends a 4-byte string. SendOctetString(string) error SendInt32(int32) error SendFileMode(os.FileMode) error SendTime(time.Time) error - // Sends len(bytes) as an octet, followed by bytes. - SendString(str string) error + // Sends len(data) as an octet, followed by the bytes. + // If data is bigger than SyncMaxChunkSize, it returns an assertion error. + SendBytes(data []byte) error } type realSyncSender struct { @@ -54,17 +57,24 @@ func (s *realSyncSender) SendTime(t time.Time) error { util.NetworkError, "error sending time on sync sender") } -func (s *realSyncSender) SendString(str string) error { - length := len(str) - if length > MaxChunkSize { +func (s *realSyncSender) SendBytes(data []byte) error { + length := len(data) + if length > SyncMaxChunkSize { // This limit might not apply to filenames, but it's big enough // that I don't think it will be a problem. - return util.AssertionErrorf("str must be <= %d in length", MaxChunkSize) + return util.AssertionErrorf("data must be <= %d in length", SyncMaxChunkSize) } if err := s.SendInt32(int32(length)); err != nil { - return util.WrapErrorf(err, util.NetworkError, "error sending string length on sync sender") + return util.WrapErrorf(err, util.NetworkError, "error sending data length on sync sender") } - return util.WrapErrorf(writeFully(s.Writer, []byte(str)), - util.NetworkError, "error sending string on sync sender") + return util.WrapErrorf(writeFully(s.Writer, data), + util.NetworkError, "error sending data on sync sender") +} + +func (s *realSyncSender) Close() error { + if closer, ok := s.Writer.(io.Closer); ok { + return util.WrapErrorf(closer.Close(), util.NetworkError, "error closing sync sender") + } + return nil } diff --git a/wire/sync_test.go b/wire/sync_test.go index 2c928fb..5321639 100644 --- a/wire/sync_test.go +++ b/wire/sync_test.go @@ -60,10 +60,10 @@ func TestSyncReadStringTooShort(t *testing.T) { assert.Equal(t, errIncompleteMessage("bytes", 1, 5), err) } -func TestSyncSendString(t *testing.T) { +func TestSyncSendBytes(t *testing.T) { var buf bytes.Buffer s := NewSyncSender(&buf) - err := s.SendString("hello") + err := s.SendBytes([]byte("hello")) assert.NoError(t, err) assert.Equal(t, "\005\000\000\000hello", buf.String()) }