massive commit to get umass working, only significant known bug is that changes don...
authorpete-cs <pete-cs@12e2690d-a6be-4b82-a7b7-67c4a43b65c8>
Wed, 19 Jun 2013 04:45:15 +0000 (04:45 +0000)
committerpete-cs <pete-cs@12e2690d-a6be-4b82-a7b7-67c4a43b65c8>
Wed, 19 Jun 2013 04:45:15 +0000 (04:45 +0000)
git-svn-id: https://svn.code.sf.net/p/goodfet/code/trunk@1604 12e2690d-a6be-4b82-a7b7-67c4a43b65c8

client/MAXUSBApp.py
client/USB.py
client/USBClass.py [new file with mode: 0644]
client/USBConfiguration.py
client/USBDevice.py
client/USBEndpoint.py
client/USBInterface.py
client/USBKeyboard.py
client/USBMassStorage.py [new file with mode: 0644]
client/facedancer-umass.py [new file with mode: 0755]

index 335c597..53b61f7 100644 (file)
@@ -168,6 +168,13 @@ class MAXUSBApp(FacedancerApp):
         else:
             raise ValueError('endpoint ' + str(ep_num) + ' not supported')
 
+        # FIFO buffer is only 64 bytes, must loop
+        while len(data) > 64:
+            self.write_bytes(fifo_reg, data[:64])
+            self.write_register(bc_reg, 64, ack=True)
+
+            data = data[64:]
+
         self.write_bytes(fifo_reg, data)
         self.write_register(bc_reg, len(data), ack=True)
 
@@ -175,9 +182,22 @@ class MAXUSBApp(FacedancerApp):
             print(self.app_name, "wrote", bytes_as_hex(data), "to endpoint",
                     ep_num)
 
-    # TODO
+    # HACK: but given the limitations of the MAX chips, it seems necessary
     def read_from_endpoint(self, ep_num):
-        pass
+        if ep_num != 1:
+            return b''
+
+        byte_count = self.read_register(self.reg_ep1_out_byte_count)
+        if byte_count == 0:
+            return b''
+
+        data = self.read_bytes(self.reg_ep1_out_fifo, byte_count)
+
+        if self.verbose > 1:
+            print(self.app_name, "read", bytes_as_hex(data), "from endpoint",
+                    ep_num)
+
+        return data
 
     def stall_ep0(self):
         if self.verbose > 0:
@@ -205,8 +225,9 @@ class MAXUSBApp(FacedancerApp):
                 self.connected_device.handle_request(req)
 
             if irq & self.is_out1_data_avail:
-                data = b'' # TODO: read from out1
-                self.connected_device.handle_data_available(1, data)
+                data = self.read_from_endpoint(1)
+                if data:
+                    self.connected_device.handle_data_available(1, data)
                 self.clear_irq_bit(self.reg_endpoint_irq, self.is_out1_data_avail)
 
             if irq & self.is_in2_buffer_avail:
index 552b036..3261520 100644 (file)
@@ -49,5 +49,5 @@ class USB:
     }
 
     def interface_class_to_descriptor_type(interface_class):
-        return USB.if_class_to_desc_type[interface_class]
+        return USB.if_class_to_desc_type.get(interface_class, None)
 
diff --git a/client/USBClass.py b/client/USBClass.py
new file mode 100644 (file)
index 0000000..daac236
--- /dev/null
@@ -0,0 +1,35 @@
+# USBClass.py
+#
+# Contains class definition for USBClass, intended as a base class (in the OO
+# sense) for implementing device classes (in the USB sense), eg, HID devices,
+# mass storage devices.
+
+class USBClass:
+    name = "generic USB device class"
+
+    # maps bRequest to handler function
+    request_handlers = { }
+
+    def __init__(self, verbose=0):
+        self.interface = None
+        self.verbose = verbose
+
+        self.setup_request_handlers()
+
+    def set_interface(self, interface):
+        self.interface = interface
+
+    def setup_request_handlers(self):
+        """To be overridden for subclasses to modify self.class_request_handlers"""
+        pass
+
+    def handle_class_request(self, req):
+        r = req.get_request()
+
+        if r in self.class_request_handlers:
+            self.class_request_handlers[r](req)
+        else:
+            print(self.name, "unhandled class request:", req)
+            print(self.name, "stalling")
+            self.device.stall_ep0()
+
index 98a3c99..aeed133 100644 (file)
@@ -12,6 +12,14 @@ class USBConfiguration:
         self.attributes = 0xe0
         self.max_power = 0x01
 
+        self.device = None
+
+        for i in self.interfaces:
+            i.set_configuration(self)
+
+    def set_device(self, device):
+        self.device = device
+
     def set_configuration_string_index(self, i):
         self.configuration_string_index = i
 
index e25fea8..49ba0a4 100644 (file)
@@ -3,6 +3,7 @@
 # Contains class definitions for USBDevice and USBDeviceRequest.
 
 from USB import *
+from USBClass import *
 
 class USBDevice:
     name = "generic device"
@@ -42,9 +43,7 @@ class USBDevice:
         for c in self.configurations:
             csi = self.get_string_id(c.configuration_string)
             c.set_configuration_string_index(csi)
-
-            for i in c.interfaces:
-                i.device = self
+            c.set_device(self)
 
         self.state = USB.state_detached
         self.ready = False
@@ -128,38 +127,57 @@ class USBDevice:
             print(self.name, "received request", req)
 
         # figure out the intended recipient
-        if req.get_recipient() == USB.request_recipient_device:
+        recipient_type = req.get_recipient()
+        recipient = None
+        index = req.get_index()
+        if recipient_type == USB.request_recipient_device:
             recipient = self
-        elif req.get_recipient() == USB.request_recipient_interface:
-            recipient = self.configuration.interfaces[req.index]
-        elif req.get_recipient() == USB.request_recipient_endpoint:
-            try:
-                recipient = self.endpoints[req.index]
-            except KeyError:
-                self.maxusb_app.stall_ep0()
-                return
+        elif recipient_type == USB.request_recipient_interface:
+            if index < len(self.configuration.interfaces):
+                recipient = self.configuration.interfaces[index]
+        elif recipient_type == USB.request_recipient_endpoint:
+            recipient = self.endpoints.get(index, None)
+
+        if not recipient:
+            print(self.name, "invalid recipient, stalling")
+            self.maxusb_app.stall_ep0()
+            return
 
         # and then the type
-        if req.get_type() == USB.request_type_standard:
-            handler = recipient.request_handlers[req.request]
-            handler(req)
-        elif req.get_type() == USB.request_type_class:
-            # HACK: evidently, FreeBSD doesn't pay attention to the device
-            # until it sends a GET_STATUS(class) message
-            self.ready = True
+        req_type = req.get_type()
+        handler_entity = None
+        if req_type == USB.request_type_standard:
+            handler_entity = recipient
+        elif req_type == USB.request_type_class:
+            handler_entity = recipient.device_class
+        elif req_type == USB.request_type_vendor:
+            handler_entity = recipient.get_device_vendor()
+
+        if not handler_entity:
+            print(self.name, "invalid handler entity, stalling")
             self.maxusb_app.stall_ep0()
-        elif req.get_type() == USB.request_type_vendor:
+            return
+
+        handler = handler_entity.request_handlers.get(req.request, None)
+
+        if not handler:
+            print(self.name, "invalid handler, stalling")
             self.maxusb_app.stall_ep0()
+            return
+
+        handler(req)
 
     def handle_data_available(self, ep_num, data):
         if self.state == USB.state_configured and ep_num in self.endpoints:
             endpoint = self.endpoints[ep_num]
-            endpoint.handler(data)
+            if callable(endpoint.handler):
+                endpoint.handler(data)
 
     def handle_buffer_available(self, ep_num):
         if self.state == USB.state_configured and ep_num in self.endpoints:
             endpoint = self.endpoints[ep_num]
-            endpoint.handler()
+            if callable(endpoint.handler):
+                endpoint.handler()
     
     # standard request handlers
     #####################################################
@@ -327,3 +345,12 @@ class USBDeviceRequest:
     def get_recipient(self):
         return self.request_type & 0x1f
 
+    # meaning of bits in wIndex changes whether we're talking about an
+    # interface or an endpoint (see USB 2.0 spec section 9.3.4)
+    def get_index(self):
+        rec = self.get_recipient()
+        if rec == 1:                # interface
+            return self.index
+        elif rec == 2:              # endpoint
+            return self.index & 0x0f
+
index a0d1c57..352d6a1 100644 (file)
@@ -32,6 +32,20 @@ class USBEndpoint:
         self.interval           = interval
         self.handler            = handler
 
+        self.interface          = None
+
+        self.request_handlers   = {
+                1 : self.handle_clear_feature_request
+        }
+
+    def handle_clear_feature_request(self, req):
+        print("received CLEAR_FEATURE request for endpoint", self.number,
+                "with value", req.value)
+        self.interface.configuration.device.maxusb_app.send_on_endpoint(0, b'')
+
+    def set_interface(self, interface):
+        self.interface = interface
+
     # see Table 9-13 of USB 2.0 spec (pdf page 297)
     def get_descriptor(self):
         address = (self.number & 0x0f) | (self.direction << 7) 
index 449b557..c3b1bfe 100644 (file)
@@ -30,6 +30,16 @@ class USBInterface:
             11 : self.handle_set_interface_request
         }
 
+        self.configuration = None
+
+        for e in self.endpoints:
+            e.set_interface(self)
+
+        self.device_class = None
+
+    def set_configuration(self, config):
+        self.configuration = config
+
     # USB 2.0 specification, section 9.4.3 (p 281 of pdf)
     # HACK: blatant copypasta from USBDevice pains me deeply
     def handle_get_descriptor_request(self, req):
@@ -52,13 +62,13 @@ class USBInterface:
 
         if response:
             n = min(n, len(response))
-            self.device.maxusb_app.send_on_endpoint(0, response[:n])
+            self.configuration.device.maxusb_app.send_on_endpoint(0, response[:n])
 
             if self.verbose > 5:
                 print(self.name, "sent", n, "bytes in response")
 
     def handle_set_interface_request(self, req):
-        self.device.maxusb_app.stall_ep0()
+        self.configuration.device.maxusb_app.stall_ep0()
 
     # Table 9-12 of USB 2.0 spec (pdf page 296)
     def get_descriptor(self):
@@ -77,7 +87,8 @@ class USBInterface:
 
         if self.iclass:
             iclass_desc_num = USB.interface_class_to_descriptor_type(self.iclass)
-            d += self.descriptors[iclass_desc_num]
+            if iclass_desc_num:
+                d += self.descriptors[iclass_desc_num]
 
         for e in self.endpoints:
             d += e.get_descriptor()
index 4dee1f9..2d5f99e 100644 (file)
@@ -64,7 +64,7 @@ class USBKeyboardInterface(USBInterface):
         if self.verbose > 2:
             print(self.name, "sending keypress 0x%02x" % ord(letter))
 
-        self.device.maxusb_app.send_on_endpoint(3, data)
+        self.configuration.device.maxusb_app.send_on_endpoint(3, data)
 
 
 class USBKeyboardDevice(USBDevice):
diff --git a/client/USBMassStorage.py b/client/USBMassStorage.py
new file mode 100644 (file)
index 0000000..22cfefc
--- /dev/null
@@ -0,0 +1,347 @@
+# USBMassStorage.py 
+#
+# Contains class definitions to implement a USB mass storage device.
+
+from mmap import mmap
+import os
+
+from USB import *
+from USBDevice import *
+from USBConfiguration import *
+from USBInterface import *
+from USBEndpoint import *
+from USBClass import *
+
+from util import *
+
+class USBMassStorageClass(USBClass):
+    name = "USB mass storage class"
+
+    def setup_request_handlers(self):
+        self.request_handlers = {
+            0xFF : self.handle_bulk_only_mass_storage_reset_request,
+            0xFE : self.handle_get_max_lun_request
+        }
+
+    def handle_bulk_only_mass_storage_reset_request(self, req):
+        self.interface.configuration.device.maxusb_app.send_on_endpoint(0, b'')
+
+    def handle_get_max_lun_request(self, req):
+        self.interface.configuration.device.maxusb_app.send_on_endpoint(0, b'\x00')
+
+
+class USBMassStorageInterface(USBInterface):
+    name = "USB mass storage interface"
+
+    def __init__(self, disk_image, verbose=0):
+        self.disk_image = disk_image
+        descriptors = { }
+
+        endpoints = [
+            USBEndpoint(
+                1,          # endpoint number
+                USBEndpoint.direction_out,
+                USBEndpoint.transfer_type_bulk,
+                USBEndpoint.sync_type_none,
+                USBEndpoint.usage_type_data,
+                16384,      # max packet size
+                0,          # polling interval, see USB 2.0 spec Table 9-13
+                self.handle_data_available    # handler function
+            ),
+            USBEndpoint(
+                3,          # endpoint number
+                USBEndpoint.direction_in,
+                USBEndpoint.transfer_type_bulk,
+                USBEndpoint.sync_type_none,
+                USBEndpoint.usage_type_data,
+                16384,      # max packet size
+                0,          # polling interval, see USB 2.0 spec Table 9-13
+                None        # handler function
+            )
+        ]
+
+        # TODO: un-hardcode string index (last arg before "verbose")
+        USBInterface.__init__(
+                self,
+                0,          # interface number
+                0,          # alternate setting
+                8,          # interface class: Mass Storage
+                6,          # subclass: SCSI transparent command set
+                0x50,       # protocol: bulk-only (BBB) transport
+                0,          # string index
+                verbose,
+                endpoints,
+                descriptors
+        )
+
+        self.device_class = USBMassStorageClass()
+        self.device_class.set_interface(self)
+
+        self.is_write_in_progress = False
+        self.write_cbw = None
+        self.write_base_lba = 0
+        self.write_length = 0
+        self.write_data = b''
+
+    def handle_data_available(self, data):
+        print(self.name, "handling", len(data), "bytes of SCSI data")
+
+        cbw = CommandBlockWrapper(data)
+        opcode = cbw.cb[0]
+
+        status = 0              # default to success
+        response = None         # with no response data
+
+        if self.is_write_in_progress:
+            if self.verbose > 0:
+                print(self.name, "got", len(data), "bytes of SCSI write data")
+
+            self.write_data += data
+
+            if len(self.write_data) < self.write_length:
+                # more yet to read, don't send the CSW
+                return
+
+            self.disk_image.put_sector_data(self.write_base_lba, self.write_data)
+            cbw = self.write_cbw
+
+            self.is_write_in_progress = False
+            self.write_data = b''
+
+        elif opcode == 0x00:      # Test Unit Ready: just return OK status
+            if self.verbose > 0:
+                print(self.name, "got SCSI Test Unit Ready")
+
+        elif opcode == 0x03:    # Request Sense
+            if self.verbose > 0:
+                print(self.name, "got SCSI Request Sense, data",
+                        bytes_as_hex(cbw.cb[1:]))
+
+            response = b'\x70\x00\xFF\x00\x00\x00\x00\x0A\x00\x00\x00\x00\xFF\xFF\x00\x00\x00\x00\x00\x00\x00\x00\x00'
+
+        elif opcode == 0x12:    # Inquiry
+            if self.verbose > 0:
+                print(self.name, "got SCSI Inquiry, data",
+                        bytes_as_hex(cbw.cb[1:]))
+
+            response = bytes([
+                0x00,       # 00 for Direct, 1F for "no floppy"
+                0x00,       # make 0x80 for removable media, 0x00 for fixed
+                0x00,       # Version
+                0x01,       # Response Data Format
+                0x14,       # Additional length.
+                0x00, 0x00, 0x00
+            ])
+
+            response += b'GoodFET '         # vendor
+            response += b'GoodFET '         # product id
+            response += b'        '         # product revision
+            response += b'0.01'
+
+            # pad up to data_transfer_length bytes
+            #diff = cbw.data_transfer_length - len(response)
+            #response += bytes([0] * diff)
+
+        elif opcode == 0x1e:    # Prevent/Allow Removal: feign success
+            if self.verbose > 0:
+                print(self.name, "got SCSI Prevent/Allow Removal")
+
+        #elif opcode == 0x1a or opcode == 0x5a:      # Mode Sense (6 or 10)
+            # TODO
+
+        elif opcode == 0x23:    # Read Format Capacity
+            if self.verbose > 0:
+                print(self.name, "got SCSI Read Format Capacity")
+
+            response = bytes([
+                0x00, 0x00, 0x00, 0x08,     # capacity list length
+                0x00, 0x00, 0x10, 0x00,     # number of sectors (0x1000 = 10MB)
+                0x10, 0x00,                 # reserved/descriptor code
+                0x02, 0x00,                 # 512-byte sectors
+            ])
+
+        elif opcode == 0x25:    # Read Capacity
+            if self.verbose > 0:
+                print(self.name, "got SCSI Read Capacity, data",
+                        bytes_as_hex(cbw.cb[1:]))
+
+            lastlba = self.disk_image.get_sector_count()
+
+            response = bytes([
+                (lastlba >> 24) & 0xff,
+                (lastlba >> 16) & 0xff,
+                (lastlba >>  8) & 0xff,
+                (lastlba      ) & 0xff,
+                0x00, 0x00, 0x02, 0x00,     # 512-byte blocks
+            ])
+
+        elif opcode == 0x28:    # Read (10)
+            base_lba = cbw.cb[2] << 24 \
+                     | cbw.cb[3] << 16 \
+                     | cbw.cb[4] << 8 \
+                     | cbw.cb[5]
+
+            num_blocks = cbw.cb[7] << 8 \
+                       | cbw.cb[8]
+
+            if self.verbose > 0:
+                print(self.name, "got SCSI Read (10), lba", base_lba, "+",
+                        num_blocks, "block(s)")
+                        
+
+            # Note that here we send the data directly rather than putting
+            # something in 'response' and letting the end of the switch send
+            for block_num in range(num_blocks):
+                data = self.disk_image.get_sector_data(base_lba + block_num)
+                self.configuration.device.maxusb_app.send_on_endpoint(3, data)
+
+        elif opcode == 0x2a:    # Write (10)
+            if self.verbose > 0:
+                print(self.name, "got SCSI Write (10), data",
+                        bytes_as_hex(cbw.cb[1:]))
+
+            base_lba = cbw.cb[1] << 24 \
+                     | cbw.cb[2] << 16 \
+                     | cbw.cb[3] <<  8 \
+                     | cbw.cb[4]
+
+            num_blocks = cbw.cb[7] << 8 \
+                       | cbw.cb[8]
+
+            if self.verbose > 0:
+                print(self.name, "got SCSI Write (10), lba", base_lba, "+",
+                        num_blocks, "block(s)")
+
+            # save for later
+            self.write_cbw = cbw
+            self.write_base_lba = base_lba
+            self.write_length = num_blocks * self.disk_image.block_size
+            self.is_write_in_progress = True
+
+            # because we need to snarf up the data from wire before we reply
+            # with the CSW
+            return
+
+        elif opcode == 0x35:    # Synchronize Cache (10): blindly OK
+            if self.verbose > 0:
+                print(self.name, "got Synchronize Cache (10)")
+
+        else:
+            print(self.name, "received unsupported SCSI opcode 0x%x" % opcode)
+            status = 0x02   # command failed
+            if cbw.data_transfer_length > 0:
+                response = bytes([0] * cbw.data_transfer_length)
+
+        if response:
+            if self.verbose > 2:
+                print(self.name, "responding with", len(response), "bytes:",
+                        bytes_as_hex(response))
+
+            self.configuration.device.maxusb_app.send_on_endpoint(3, response)
+
+        csw = bytes([
+            ord('U'), ord('S'), ord('B'), ord('S'),
+            cbw.tag[0], cbw.tag[1], cbw.tag[2], cbw.tag[3],
+            0x00, 0x00, 0x00, 0x00,
+            status
+        ])
+
+        if self.verbose > 3:
+            print(self.name, "responding with status =", status)
+
+        self.configuration.device.maxusb_app.send_on_endpoint(3, csw)
+
+
+class DiskImage:
+    def __init__(self, filename, block_size):
+        self.filename = filename
+        self.block_size = block_size
+
+        statinfo = os.stat(self.filename)
+        self.size = statinfo.st_size
+
+        self.file = open(self.filename, 'r+b')
+        self.image = mmap(self.file.fileno(), 0)
+
+    def close(self):
+        self.image.flush()
+        self.image.close()
+
+    def get_sector_count(self):
+        return int(self.size / self.block_size) - 1
+
+    def get_sector_data(self, address):
+        block_start = address * self.block_size
+        block_end   = (address + 1) * self.block_size   # slices are NON-inclusive
+
+        return self.image[block_start:block_end]
+
+    def put_sector_data(self, address, data):
+        block_start = address * self.block_size
+        block_end   = (address + 1) * self.block_size   # slices are NON-inclusive
+
+        self.image[block_start:block_end] = data[:self.block_size]
+        self.image.flush()
+
+
+class CommandBlockWrapper:
+    def __init__(self, bytestring):
+        self.signature              = bytestring[0:4]
+        self.tag                    = bytestring[4:8]
+        self.data_transfer_length   = bytestring[8] \
+                                    | bytestring[9] << 8 \
+                                    | bytestring[10] << 16 \
+                                    | bytestring[11] << 24
+        self.flags                  = int(bytestring[12])
+        self.lun                    = int(bytestring[13] & 0x0f)
+        self.cb_length              = int(bytestring[14] & 0x1f)
+        #self.cb                     = bytestring[15:15+self.cb_length]
+        self.cb                     = bytestring[15:]
+
+    def __str__(self):
+        s  = "sig: " + bytes_as_hex(self.signature) + "\n"
+        s += "tag: " + bytes_as_hex(self.tag) + "\n"
+        s += "data transfer len: " + str(self.data_transfer_length) + "\n"
+        s += "flags: " + str(self.flags) + "\n"
+        s += "lun: " + str(self.lun) + "\n"
+        s += "command block len: " + str(self.cb_length) + "\n"
+        s += "command block: " + bytes_as_hex(self.cb) + "\n"
+
+        return s
+
+
+class USBMassStorageDevice(USBDevice):
+    name = "USB mass storage device"
+
+    def __init__(self, maxusb_app, disk_image_filename, verbose=0):
+        self.disk_image = DiskImage(disk_image_filename, 512)
+
+        interface = USBMassStorageInterface(self.disk_image, verbose=verbose)
+
+        config = USBConfiguration(
+                1,                                          # index
+                "Maxim umass config",                       # string desc
+                [ interface ]                               # interfaces
+        )
+
+        USBDevice.__init__(
+                self,
+                maxusb_app,
+                0,                      # device class
+                0,                      # device subclass
+                0,                      # protocol release number
+                64,                     # max packet size for endpoint 0
+                0x8107,                 # vendor id: Sandisk
+                0x5051,                 # product id: SDCZ2 Cruzer Mini Flash Drive (thin)
+                0x0003,                 # device revision
+                "Maxim",                # manufacturer string
+                "MAX3420E Enum Code",   # product string
+                "S/N3420E",             # serial number string
+                [ config ],
+                verbose=verbose
+        )
+
+    def disconnect(self):
+        self.disk_image.close()
+        USBDevice.disconnect(self)
+
diff --git a/client/facedancer-umass.py b/client/facedancer-umass.py
new file mode 100755 (executable)
index 0000000..102ab5a
--- /dev/null
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+#
+# facedancer-umass.py
+#
+# Creating a disk image under linux:
+#
+#   # fallocate -l 100M disk.img
+#   # fdisk disk.img
+#   # losetup -f --show disk.img
+#   # kpartx -a /dev/loopX
+#   # mkfs.XXX /dev/mapper/loopXpY
+#   # mount /dev/mapper/loopXpY /mnt/point
+#       do stuff on /mnt/point
+#   # umount /mnt/point
+#   # kpartx -d /dev/loopX
+#   # losetup -d /dev/loopX
+
+from serial import Serial, PARITY_NONE
+
+from Facedancer import *
+from MAXUSBApp import *
+from USBMassStorage import *
+
+sp = Serial("/dev/ttyUSB0", 115200, parity=PARITY_NONE, timeout=2)
+fd = Facedancer(sp, verbose=1)
+u = MAXUSBApp(fd, verbose=1)
+
+d = USBMassStorageDevice(u, "test.img", verbose=3)
+
+d.connect()
+
+try:
+    d.run()
+except KeyboardInterrupt:
+    d.disconnect()
+