From d0b99d9f576ace29b15b6c36e6cc83103501e924 Mon Sep 17 00:00:00 2001 From: Zach Klippenstein Date: Fri, 24 Apr 2015 08:20:47 -0500 Subject: [PATCH] Parse more bits out of file modes in sync mode. --- cmd/demo/demo.go | 2 +- device_client.go | 4 +- dir_entries.go | 10 +++ host_client.go | 4 + host_client_test.go | 2 +- sync_client.go | 14 ---- sync_file_reader_test.go | 4 +- wire/conn.go | 20 +++-- wire/dialer.go | 1 - wire/filemode.go | 33 ++++++++ wire/scanner.go | 11 ++- wire/scanner_test.go | 16 +++- wire/sender.go | 11 ++- wire/sender_test.go | 16 +++- wire/sync.go | 168 --------------------------------------- wire/sync_scanner.go | 112 ++++++++++++++++++++++++++ wire/sync_sender.go | 61 ++++++++++++++ 17 files changed, 279 insertions(+), 210 deletions(-) create mode 100644 wire/filemode.go create mode 100644 wire/sync_scanner.go create mode 100644 wire/sync_sender.go diff --git a/cmd/demo/demo.go b/cmd/demo/demo.go index 5ec45ac..296fd42 100644 --- a/cmd/demo/demo.go +++ b/cmd/demo/demo.go @@ -16,7 +16,7 @@ var port = flag.Int("p", wire.AdbPort, "") func main() { flag.Parse() - client := adb.NewHostClientDialer(wire.NewDialer("", *port)) + client := adb.NewHostClientPort(*port) fmt.Println("Starting server…") client.StartServer() diff --git a/device_client.go b/device_client.go index 6bba1e6..6c59d47 100644 --- a/device_client.go +++ b/device_client.go @@ -8,9 +8,7 @@ import ( "github.com/zach-klippenstein/goadb/wire" ) -/* -DeviceClient communicates with a specific Android device. -*/ +// DeviceClient communicates with a specific Android device. type DeviceClient struct { dialer nilSafeDialer descriptor *DeviceDescriptor diff --git a/dir_entries.go b/dir_entries.go index 6603469..4e49884 100644 --- a/dir_entries.go +++ b/dir_entries.go @@ -2,10 +2,20 @@ package goadb import ( "fmt" + "os" + "time" "github.com/zach-klippenstein/goadb/wire" ) +// DirEntry holds information about a directory entry on a device. +type DirEntry struct { + Name string + Mode os.FileMode + Size int32 + ModifiedAt time.Time +} + // DirEntries iterates over directory entries. type DirEntries struct { scanner wire.SyncScanner diff --git a/host_client.go b/host_client.go index dbecccc..e908f66 100644 --- a/host_client.go +++ b/host_client.go @@ -28,6 +28,10 @@ func NewHostClient() *HostClient { return NewHostClientDialer(nil) } +func NewHostClientPort(port int) *HostClient { + return NewHostClientDialer(wire.NewDialer("", port)) +} + func NewHostClientDialer(d wire.Dialer) *HostClient { return &HostClient{nilSafeDialer{d}} } diff --git a/host_client_test.go b/host_client_test.go index 23a3451..7638f85 100644 --- a/host_client_test.go +++ b/host_client_test.go @@ -35,7 +35,7 @@ type MockServer struct { } func (s *MockServer) Dial() (*wire.Conn, error) { - return wire.NewConn(s, s, s.Close), nil + return wire.NewConn(s, s), nil } func (s *MockServer) ReadStatus() (wire.StatusCode, error) { diff --git a/sync_client.go b/sync_client.go index 61b063d..6b3e308 100644 --- a/sync_client.go +++ b/sync_client.go @@ -4,24 +4,10 @@ package goadb import ( "fmt" "io" - "os" - "time" "github.com/zach-klippenstein/goadb/wire" ) -/* -DirEntry holds information about a directory entry on a device. - -Unfortunately, adb doesn't seem to set the directory bit for directories. -*/ -type DirEntry struct { - Name string - Mode os.FileMode - Size int32 - ModifiedAt time.Time -} - func stat(conn *wire.SyncConn, path string) (*DirEntry, error) { if err := conn.SendOctetString("STAT"); err != nil { return nil, err diff --git a/sync_file_reader_test.go b/sync_file_reader_test.go index cd7bfc5..67033ba 100644 --- a/sync_file_reader_test.go +++ b/sync_file_reader_test.go @@ -16,7 +16,7 @@ func TestReadNextChunk(t *testing.T) { // Read 1st chunk reader, err := readNextChunk(s) assert.NoError(t, err) - assert.Equal(t, 6, reader.(*io.LimitedReader).N) + assert.Equal(t, int64(6), reader.(*io.LimitedReader).N) buf := make([]byte, 10) n, err := reader.Read(buf) assert.NoError(t, err) @@ -26,7 +26,7 @@ func TestReadNextChunk(t *testing.T) { // Read 2nd chunk reader, err = readNextChunk(s) assert.NoError(t, err) - assert.Equal(t, 5, reader.(*io.LimitedReader).N) + assert.Equal(t, int64(5), reader.(*io.LimitedReader).N) buf = make([]byte, 10) n, err = reader.Read(buf) assert.NoError(t, err) diff --git a/wire/conn.go b/wire/conn.go index be49829..824b453 100644 --- a/wire/conn.go +++ b/wire/conn.go @@ -26,19 +26,10 @@ You should still always call Close() when you're done with the connection. type Conn struct { Scanner Sender - closer func() error } -func NewConn(scanner Scanner, sender Sender, closer func() error) *Conn { - return &Conn{scanner, sender, closer} -} - -// Close closes the underlying connection. -func (c *Conn) Close() error { - if c.closer != nil { - return c.closer() - } - return nil +func NewConn(scanner Scanner, sender Sender) *Conn { + return &Conn{scanner, sender} } // NewSyncConn returns connection that can operate in sync mode. @@ -64,3 +55,10 @@ func (conn *Conn) RoundTripSingleResponse(req []byte) (resp []byte, err error) { return conn.ReadMessage() } + +func (conn *Conn) Close() error { + if err := conn.Sender.Close(); err != nil { + return err + } + return conn.Scanner.Close() +} diff --git a/wire/dialer.go b/wire/dialer.go index ccf883d..bac2683 100644 --- a/wire/dialer.go +++ b/wire/dialer.go @@ -48,7 +48,6 @@ func (d *netDialer) Dial() (*Conn, error) { conn := &Conn{ Scanner: NewScanner(netConn), Sender: NewSender(netConn), - closer: netConn.Close, } // Prevent leaking the network connection, not sure if TCPConn does this itself. diff --git a/wire/filemode.go b/wire/filemode.go new file mode 100644 index 0000000..9c55250 --- /dev/null +++ b/wire/filemode.go @@ -0,0 +1,33 @@ +package wire + +import "os" + +// ADB file modes seem to only be 16 bits. +// Values are taken from http://linux.die.net/include/bits/stat.h. +const ( + ModeDir uint32 = 0040000 + ModeSymlink = 0120000 + ModeSocket = 0140000 + ModeFifo = 0010000 + ModeCharDevice = 0020000 +) + +func ParseFileModeFromAdb(modeFromSync uint32) (filemode os.FileMode) { + // The ADB filemode uses the permission bits defined in Go's os package, but + // we need to parse the other bits manually. + switch { + case modeFromSync&ModeSymlink == ModeSymlink: + filemode = os.ModeSymlink + case modeFromSync&ModeDir == ModeDir: + filemode = os.ModeDir + case modeFromSync&ModeSocket == ModeSocket: + filemode = os.ModeSocket + case modeFromSync&ModeFifo == ModeFifo: + filemode = os.ModeNamedPipe + case modeFromSync&ModeCharDevice == ModeCharDevice: + filemode = os.ModeCharDevice + } + + filemode |= os.FileMode(modeFromSync).Perm() + return +} diff --git a/wire/scanner.go b/wire/scanner.go index e8ffa32..48a972c 100644 --- a/wire/scanner.go +++ b/wire/scanner.go @@ -29,14 +29,17 @@ type Scanner interface { ReadStatus() (StatusCode, error) ReadMessage() ([]byte, error) ReadUntilEof() ([]byte, error) + NewSyncScanner() SyncScanner + + Close() error } type realScanner struct { - reader io.Reader + reader io.ReadCloser } -func NewScanner(r io.Reader) Scanner { +func NewScanner(r io.ReadCloser) Scanner { return &realScanner{r} } @@ -108,6 +111,10 @@ func ReadStatusFailureAsError(s Scanner, req []byte) error { return nil } +func (s *realScanner) Close() error { + return s.reader.Close() +} + func (s *realScanner) readLength() (int, error) { lengthHex := make([]byte, 4) n, err := io.ReadFull(s.reader, lengthHex) diff --git a/wire/scanner_test.go b/wire/scanner_test.go index cf958e2..20b402b 100644 --- a/wire/scanner_test.go +++ b/wire/scanner_test.go @@ -89,8 +89,10 @@ func NewScannerString(str string) *realScanner { // NewEofBuffer returns a bytes.Buffer of str that returns an EOF error // at the end of input, instead of just returning 0 bytes read. -func NewEofBuffer(str string) *bufio.Reader { - return bufio.NewReader(io.LimitReader(bytes.NewBufferString(str), int64(len(str)))) +func NewEofBuffer(str string) *TestReader { + limitReader := io.LimitReader(bytes.NewBufferString(str), int64(len(str))) + bufReader := bufio.NewReader(limitReader) + return &TestReader{bufReader} } func assertEof(t *testing.T, s *realScanner) { @@ -104,3 +106,13 @@ func assertNotEof(t *testing.T, s *realScanner) { assert.Equal(t, 1, n) assert.NoError(t, err) } + +// TestReader is a wrapper around a bufio.Reader that implements io.Closer. +type TestReader struct { + *bufio.Reader +} + +func (b *TestReader) Close() error { + // No-op. + return nil +} diff --git a/wire/sender.go b/wire/sender.go index 79444d6..d80363f 100644 --- a/wire/sender.go +++ b/wire/sender.go @@ -8,14 +8,17 @@ import ( // Sender sends messages to the server. type Sender interface { SendMessage(msg []byte) error + NewSyncSender() SyncSender + + Close() error } type realSender struct { - writer io.Writer + writer io.WriteCloser } -func NewSender(w io.Writer) Sender { +func NewSender(w io.WriteCloser) Sender { return &realSender{w} } @@ -36,4 +39,8 @@ func (s *realSender) NewSyncSender() SyncSender { return NewSyncSender(s.writer) } +func (s *realSender) Close() error { + return s.writer.Close() +} + var _ Sender = &realSender{} diff --git a/wire/sender_test.go b/wire/sender_test.go index 3b42b11..c7bd4ac 100644 --- a/wire/sender_test.go +++ b/wire/sender_test.go @@ -21,7 +21,17 @@ func TestWriteEmptyMessage(t *testing.T) { assert.Equal(t, "0000", b.String()) } -func NewTestSender() (Sender, *bytes.Buffer) { - var buf bytes.Buffer - return NewSender(&buf), &buf +func NewTestSender() (Sender, *TestWriter) { + w := new(TestWriter) + return NewSender(w), w +} + +// TestWriter is a wrapper around a bytes.Buffer that implements io.Closer. +type TestWriter struct { + bytes.Buffer +} + +func (b *TestWriter) Close() error { + // No-op. + return nil } diff --git a/wire/sync.go b/wire/sync.go index d69f63f..521efc0 100644 --- a/wire/sync.go +++ b/wire/sync.go @@ -1,14 +1,6 @@ // TODO(z): Write SyncSender.SendBytes(). package wire -import ( - "encoding/binary" - "fmt" - "io" - "os" - "time" -) - const ( // Chunks cannot be longer than 64k. MaxChunkSize = 64 * 1024 @@ -36,163 +28,3 @@ type SyncConn struct { SyncScanner SyncSender } - -func (c *SyncConn) Close() error { - return c.SyncScanner.Close() -} - -type SyncScanner interface { - // ReadOctetString reads a 4-byte string. - ReadOctetString() (string, error) - ReadInt32() (int32, error) - ReadFileMode() (os.FileMode, error) - ReadTime() (time.Time, error) - - // Reads an octet length, followed by length bytes. - ReadString() (string, error) - - // Reads an octet length, and returns a reader that will read length - // bytes (see io.LimitReader). The returned reader should be fully - // read before reading anything off the Scanner again. - ReadBytes() (io.Reader, error) - - // Closes the underlying reader. - Close() error -} - -type SyncSender interface { - // SendOctetString sends a 4-byte string. - SendOctetString(string) error - SendInt32(int32) error - SendFileMode(os.FileMode) error - SendTime(time.Time) error - - // Sends len(bytes) as an octet, followed by bytes. - SendString(str string) error -} - -type realSyncScanner struct { - io.Reader -} - -type realSyncSender struct { - io.Writer -} - -func NewSyncScanner(r io.Reader) SyncScanner { - return &realSyncScanner{r} -} - -func NewSyncSender(w io.Writer) SyncSender { - return &realSyncSender{w} -} - -func RequireOctetString(s SyncScanner, expected string) error { - actual, err := s.ReadOctetString() - if err != nil { - return fmt.Errorf("expected to read '%s', got err: %v", expected, err) - } - if actual != expected { - return fmt.Errorf("expected to read '%s', got '%s'", expected, actual) - } - return nil -} - -func (s *realSyncScanner) ReadOctetString() (string, error) { - octet := make([]byte, 4) - n, err := io.ReadFull(s.Reader, octet) - if err != nil && err != io.ErrUnexpectedEOF { - return "", err - } else if err == io.ErrUnexpectedEOF { - return "", incompleteMessage("octet", n, 4) - } - - return string(octet), nil -} - -func (s *realSyncSender) SendOctetString(str string) error { - if len(str) != 4 { - return fmt.Errorf("octet string must be exactly 4 bytes: '%s'", str) - } - return writeFully(s.Writer, []byte(str)) -} - -func (s *realSyncScanner) ReadInt32() (int32, error) { - var value int32 - err := binary.Read(s.Reader, binary.LittleEndian, &value) - return value, err -} - -func (s *realSyncSender) SendInt32(val int32) error { - return binary.Write(s.Writer, binary.LittleEndian, val) -} - -func (s *realSyncScanner) ReadFileMode() (os.FileMode, error) { - var value uint32 - err := binary.Read(s.Reader, binary.LittleEndian, &value) - return os.FileMode(value), err -} - -func (s *realSyncSender) SendFileMode(mode os.FileMode) error { - return binary.Write(s.Writer, binary.LittleEndian, mode) -} - -func (s *realSyncScanner) ReadTime() (time.Time, error) { - seconds, err := s.ReadInt32() - if err != nil { - return time.Time{}, err - } - - return time.Unix(int64(seconds), 0).UTC(), nil -} - -func (s *realSyncSender) SendTime(t time.Time) error { - return s.SendInt32(int32(t.Unix())) -} - -func (s *realSyncScanner) ReadString() (string, error) { - length, err := s.ReadInt32() - if err != nil { - return "", err - } - - bytes := make([]byte, length) - n, err := io.ReadFull(s.Reader, bytes) - if err != nil && err != io.ErrUnexpectedEOF { - return "", err - } else if err == io.ErrUnexpectedEOF { - return "", incompleteMessage("bytes", n, int(length)) - } - - return string(bytes), nil -} - -func (s *realSyncSender) SendString(str string) error { - length := len(str) - if length > MaxChunkSize { - // This limit might not apply to filenames, but it's big enough - // that I don't think it will be a problem. - return fmt.Errorf("str must be <= %d in length", MaxChunkSize) - } - - if err := s.SendInt32(int32(length)); err != nil { - return err - } - return writeFully(s.Writer, []byte(str)) -} - -func (s *realSyncScanner) ReadBytes() (io.Reader, error) { - length, err := s.ReadInt32() - if err != nil { - return nil, err - } - - return io.LimitReader(s.Reader, int64(length)), nil -} - -func (s *realSyncScanner) Close() error { - if closer, ok := s.Reader.(io.Closer); ok { - return closer.Close() - } - return nil -} diff --git a/wire/sync_scanner.go b/wire/sync_scanner.go new file mode 100644 index 0000000..34c8a8b --- /dev/null +++ b/wire/sync_scanner.go @@ -0,0 +1,112 @@ +package wire + +import ( + "encoding/binary" + "fmt" + "io" + "os" + "time" +) + +type SyncScanner interface { + // ReadOctetString reads a 4-byte string. + ReadOctetString() (string, error) + ReadInt32() (int32, error) + ReadFileMode() (os.FileMode, error) + ReadTime() (time.Time, error) + + // Reads an octet length, followed by length bytes. + ReadString() (string, error) + + // Reads an octet length, and returns a reader that will read length + // bytes (see io.LimitReader). The returned reader should be fully + // read before reading anything off the Scanner again. + ReadBytes() (io.Reader, error) + + // Closes the underlying reader. + Close() error +} + +type realSyncScanner struct { + io.Reader +} + +func NewSyncScanner(r io.Reader) SyncScanner { + return &realSyncScanner{r} +} + +func RequireOctetString(s SyncScanner, expected string) error { + actual, err := s.ReadOctetString() + if err != nil { + return fmt.Errorf("expected to read '%s', got err: %v", expected, err) + } + if actual != expected { + return fmt.Errorf("expected to read '%s', got '%s'", expected, actual) + } + return nil +} + +func (s *realSyncScanner) ReadOctetString() (string, error) { + octet := make([]byte, 4) + n, err := io.ReadFull(s.Reader, octet) + if err != nil && err != io.ErrUnexpectedEOF { + return "", err + } else if err == io.ErrUnexpectedEOF { + return "", incompleteMessage("octet", n, 4) + } + + return string(octet), nil +} +func (s *realSyncScanner) ReadInt32() (int32, error) { + var value int32 + err := binary.Read(s.Reader, binary.LittleEndian, &value) + return value, err +} +func (s *realSyncScanner) ReadFileMode() (filemode os.FileMode, err error) { + var value uint32 + err = binary.Read(s.Reader, binary.LittleEndian, &value) + if err == nil { + filemode = ParseFileModeFromAdb(value) + } + return +} +func (s *realSyncScanner) ReadTime() (time.Time, error) { + seconds, err := s.ReadInt32() + if err != nil { + return time.Time{}, err + } + + return time.Unix(int64(seconds), 0).UTC(), nil +} + +func (s *realSyncScanner) ReadString() (string, error) { + length, err := s.ReadInt32() + if err != nil { + return "", err + } + + bytes := make([]byte, length) + n, err := io.ReadFull(s.Reader, bytes) + if err != nil && err != io.ErrUnexpectedEOF { + return "", err + } else if err == io.ErrUnexpectedEOF { + return "", incompleteMessage("bytes", n, int(length)) + } + + return string(bytes), nil +} +func (s *realSyncScanner) ReadBytes() (io.Reader, error) { + length, err := s.ReadInt32() + if err != nil { + return nil, err + } + + return io.LimitReader(s.Reader, int64(length)), nil +} + +func (s *realSyncScanner) Close() error { + if closer, ok := s.Reader.(io.Closer); ok { + return closer.Close() + } + return nil +} diff --git a/wire/sync_sender.go b/wire/sync_sender.go new file mode 100644 index 0000000..345f91e --- /dev/null +++ b/wire/sync_sender.go @@ -0,0 +1,61 @@ +package wire + +import ( + "encoding/binary" + "fmt" + "io" + "os" + "time" +) + +type SyncSender interface { + // SendOctetString sends a 4-byte string. + SendOctetString(string) error + SendInt32(int32) error + SendFileMode(os.FileMode) error + SendTime(time.Time) error + + // Sends len(bytes) as an octet, followed by bytes. + SendString(str string) error +} + +type realSyncSender struct { + io.Writer +} + +func NewSyncSender(w io.Writer) SyncSender { + return &realSyncSender{w} +} + +func (s *realSyncSender) SendOctetString(str string) error { + if len(str) != 4 { + return fmt.Errorf("octet string must be exactly 4 bytes: '%s'", str) + } + return writeFully(s.Writer, []byte(str)) +} + +func (s *realSyncSender) SendInt32(val int32) error { + return binary.Write(s.Writer, binary.LittleEndian, val) +} + +func (s *realSyncSender) SendFileMode(mode os.FileMode) error { + return binary.Write(s.Writer, binary.LittleEndian, mode) +} + +func (s *realSyncSender) SendTime(t time.Time) error { + return s.SendInt32(int32(t.Unix())) +} + +func (s *realSyncSender) SendString(str string) error { + length := len(str) + if length > MaxChunkSize { + // This limit might not apply to filenames, but it's big enough + // that I don't think it will be a problem. + return fmt.Errorf("str must be <= %d in length", MaxChunkSize) + } + + if err := s.SendInt32(int32(length)); err != nil { + return err + } + return writeFully(s.Writer, []byte(str)) +}