From 3e1d948164abc53143703164f8cf6ba3228848c8 Mon Sep 17 00:00:00 2001
From: Zach Klippenstein <zach.klippenstein@gmail.com>
Date: Tue, 29 Dec 2015 13:25:17 -0800
Subject: [PATCH] [cmd/adb] Support pulling/pushing from stdin and to stdout.

---
 cmd/adb/main.go | 82 ++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 61 insertions(+), 21 deletions(-)

diff --git a/cmd/adb/main.go b/cmd/adb/main.go
index 48f7655..5b3c4ec 100644
--- a/cmd/adb/main.go
+++ b/cmd/adb/main.go
@@ -5,6 +5,7 @@ import (
 	"io"
 	"os"
 	"path/filepath"
+	"syscall"
 	"time"
 
 	"github.com/cheggaaa/pb"
@@ -13,6 +14,8 @@ import (
 	"gopkg.in/alecthomas/kingpin.v2"
 )
 
+const StdIoFilename = "-"
+
 var (
 	serial = kingpin.Flag("serial", "Connect to device by serial number.").Short('s').String()
 
@@ -25,11 +28,11 @@ var (
 	pullCommand      = kingpin.Command("pull", "Pull a file from the device.")
 	pullProgressFlag = pullCommand.Flag("progress", "Show progress.").Short('p').Bool()
 	pullRemoteArg    = pullCommand.Arg("remote", "Path of source file on device.").Required().String()
-	pullLocalArg     = pullCommand.Arg("local", "Path of destination file.").String()
+	pullLocalArg     = pullCommand.Arg("local", "Path of destination file. If -, will write to stdout.").String()
 
 	pushCommand      = kingpin.Command("push", "Push a file to the device.")
 	pushProgressFlag = pushCommand.Flag("progress", "Show progress.").Short('p').Bool()
-	pushLocalArg     = pushCommand.Arg("local", "Path of source file.").Required().File()
+	pushLocalArg     = pushCommand.Arg("local", "Path of source file. If -, will read from stdin.").Required().String()
 	pushRemoteArg    = pushCommand.Arg("remote", "Path of destination file on device.").Required().String()
 )
 
@@ -136,10 +139,15 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe
 	}
 	defer remoteFile.Close()
 
-	localFile, err := os.Create(localPath)
-	if err != nil {
-		fmt.Fprintf(os.Stderr, "error opening local file %s: %s\n", localPath, err)
-		return 1
+	var localFile io.WriteCloser
+	if localPath == StdIoFilename {
+		localFile = os.Stdout
+	} else {
+		localFile, err = os.Create(localPath)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "error opening local file %s: %s\n", localPath, err)
+			return 1
+		}
 	}
 	defer localFile.Close()
 
@@ -150,44 +158,71 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe
 	return 0
 }
 
-func push(showProgress bool, localFile *os.File, remotePath string, device goadb.DeviceDescriptor) int {
+func push(showProgress bool, localPath, remotePath string, device goadb.DeviceDescriptor) int {
 	if remotePath == "" {
 		fmt.Fprintln(os.Stderr, "error: must specify remote file")
 		kingpin.Usage()
 		return 1
 	}
 
-	client := goadb.NewDeviceClient(goadb.ClientConfig{}, device)
-
-	info, err := os.Stat(localFile.Name())
-	if err != nil {
-		fmt.Fprintf(os.Stderr, "error reading local file %s: %s\n", localFile.Name(), err)
-		return 1
+	var (
+		localFile io.ReadCloser
+		size      int
+		perms     os.FileMode
+		mtime     time.Time
+	)
+	if localPath == "" || localPath == StdIoFilename {
+		localFile = os.Stdin
+		// 0 size will hide the progress bar.
+		perms = os.FileMode(0660)
+		mtime = goadb.MtimeOfClose
+	} else {
+		var err error
+		localFile, err = os.Open(localPath)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "error opening local file %s: %s\n", localPath, err)
+			return 1
+		}
+		info, err := os.Stat(localPath)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "error reading local file %s: %s\n", localPath, err)
+			return 1
+		}
+		size = int(info.Size())
+		perms = info.Mode().Perm()
+		mtime = info.ModTime()
 	}
+	defer localFile.Close()
 
-	writer, err := client.OpenWrite(remotePath, info.Mode(), info.ModTime())
+	client := goadb.NewDeviceClient(goadb.ClientConfig{}, device)
+	writer, err := client.OpenWrite(remotePath, perms, mtime)
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err)
 		return 1
 	}
 	defer writer.Close()
 
-	if err := copyWithProgressAndStats(writer, localFile, int(info.Size()), showProgress); err != nil {
+	if err := copyWithProgressAndStats(writer, localFile, size, showProgress); err != nil {
 		fmt.Fprintln(os.Stderr, "error pushing file:", err)
 		return 1
 	}
 	return 0
 }
 
+// copyWithProgressAndStats copies src to dst.
+// If showProgress is true and size is positive, a progress bar is shown.
+// After copying, final stats about the transfer speed and size are shown.
+// Progress and stats are printed to stderr.
 func copyWithProgressAndStats(dst io.Writer, src io.Reader, size int, showProgress bool) error {
 	var progress *pb.ProgressBar
-	if showProgress {
+	if showProgress && size > 0 {
 		progress = pb.New(size)
-		progress.SetUnits(pb.U_BYTES)
-		progress.SetRefreshRate(100 * time.Millisecond)
+		// Write to stderr in case dst is stdout.
+		progress.Output = os.Stderr
 		progress.ShowSpeed = true
 		progress.ShowPercent = true
 		progress.ShowTimeLeft = true
+		progress.SetUnits(pb.U_BYTES)
 		progress.Start()
 		dst = io.MultiWriter(dst, progress)
 	}
@@ -196,17 +231,22 @@ func copyWithProgressAndStats(dst io.Writer, src io.Reader, size int, showProgre
 	copied, err := io.Copy(dst, src)
 
 	if progress != nil {
-		// Force progress update if the transfer was really fast.
-		progress.Update()
+		progress.Finish()
 	}
 
+	if pathErr, ok := err.(*os.PathError); ok {
+		if errno, ok := pathErr.Err.(syscall.Errno); ok && errno == syscall.EPIPE {
+			// Pipe closed. Handle this like an EOF.
+			err = nil
+		}
+	}
 	if err != nil {
 		return err
 	}
 
 	duration := time.Now().Sub(startTime)
 	rate := int64(float64(copied) / duration.Seconds())
-	fmt.Printf("%d B/s (%d bytes in %s)\n", rate, copied, duration)
+	fmt.Fprintf(os.Stderr, "%d B/s (%d bytes in %s)\n", rate, copied, duration)
 
 	return nil
 }