massive commit to get umass working, only significant known bug is that changes don...
[goodfet] / client / USBMassStorage.py
1 # USBMassStorage.py 
2 #
3 # Contains class definitions to implement a USB mass storage device.
4
5 from mmap import mmap
6 import os
7
8 from USB import *
9 from USBDevice import *
10 from USBConfiguration import *
11 from USBInterface import *
12 from USBEndpoint import *
13 from USBClass import *
14
15 from util import *
16
class USBMassStorageClass(USBClass):
    """Class-specific control requests of the USB mass storage class.

    Two requests are answered on endpoint 0: 0xFF (bulk-only mass storage
    reset) and 0xFE (get max LUN).
    """
    name = "USB mass storage class"

    def setup_request_handlers(self):
        """Install the request-number -> handler table."""
        handlers = {}
        handlers[0xFF] = self.handle_bulk_only_mass_storage_reset_request
        handlers[0xFE] = self.handle_get_max_lun_request
        self.request_handlers = handlers

    def handle_bulk_only_mass_storage_reset_request(self, req):
        """Acknowledge the reset request with an empty packet on endpoint 0."""
        app = self.interface.configuration.device.maxusb_app
        app.send_on_endpoint(0, b'')

    def handle_get_max_lun_request(self, req):
        """Answer with a single zero byte on endpoint 0."""
        app = self.interface.configuration.device.maxusb_app
        app.send_on_endpoint(0, b'\x00')
31
32
class USBMassStorageInterface(USBInterface):
    """Mass storage interface that serves sectors from a disk image.

    The interface is declared as class 8 / subclass 6 / protocol 0x50
    (mass storage, SCSI transparent command set, bulk-only transport).
    Commands and write payload arrive on OUT endpoint 1; response data and
    the status wrapper are sent on IN endpoint 3.
    """
    name = "USB mass storage interface"

    def __init__(self, disk_image, verbose=0):
        """Create the interface.

        disk_image -- object providing block_size, get_sector_count(),
                      get_sector_data() and put_sector_data()
        verbose    -- verbosity level, forwarded to USBInterface
        """
        self.disk_image = disk_image
        descriptors = { }

        endpoints = [
            USBEndpoint(
                1,          # endpoint number
                USBEndpoint.direction_out,
                USBEndpoint.transfer_type_bulk,
                USBEndpoint.sync_type_none,
                USBEndpoint.usage_type_data,
                16384,      # max packet size
                0,          # polling interval, see USB 2.0 spec Table 9-13
                self.handle_data_available    # handler function
            ),
            USBEndpoint(
                3,          # endpoint number
                USBEndpoint.direction_in,
                USBEndpoint.transfer_type_bulk,
                USBEndpoint.sync_type_none,
                USBEndpoint.usage_type_data,
                16384,      # max packet size
                0,          # polling interval, see USB 2.0 spec Table 9-13
                None        # handler function
            )
        ]

        # TODO: un-hardcode string index (last arg before "verbose")
        USBInterface.__init__(
                self,
                0,          # interface number
                0,          # alternate setting
                8,          # interface class: Mass Storage
                6,          # subclass: SCSI transparent command set
                0x50,       # protocol: bulk-only (BBB) transport
                0,          # string index
                verbose,
                endpoints,
                descriptors
        )

        self.device_class = USBMassStorageClass()
        self.device_class.set_interface(self)

        # State of a Write (10) whose data phase has not finished yet.
        self.is_write_in_progress = False
        self.write_cbw = None
        self.write_base_lba = 0
        self.write_length = 0
        self.write_data = b''

    def handle_data_available(self, data):
        """Process one packet received on the bulk OUT endpoint.

        While a Write (10) is pending, the packet is raw payload and is only
        accumulated. Otherwise it is parsed as a Command Block Wrapper and
        dispatched on the SCSI opcode in its first command byte. Any
        response data is sent first, then the Command Status Wrapper.
        """
        print(self.name, "handling", len(data), "bytes of SCSI data")

        # Write payload is not a CBW: handle it before trying to parse one,
        # otherwise short data packets would blow up in the parser.
        if self.is_write_in_progress:
            self._handle_write_data(data)
            return

        cbw = CommandBlockWrapper(data)
        opcode = cbw.cb[0]

        status = 0              # default to success
        response = None         # with no response data

        if opcode == 0x00:      # Test Unit Ready: just return OK status
            if self.verbose > 0:
                print(self.name, "got SCSI Test Unit Ready")

        elif opcode == 0x03:    # Request Sense
            if self.verbose > 0:
                print(self.name, "got SCSI Request Sense, data",
                        bytes_as_hex(cbw.cb[1:]))

            response = b'\x70\x00\xFF\x00\x00\x00\x00\x0A\x00\x00\x00\x00\xFF\xFF\x00\x00\x00\x00\x00\x00\x00\x00\x00'

        elif opcode == 0x12:    # Inquiry
            if self.verbose > 0:
                print(self.name, "got SCSI Inquiry, data",
                        bytes_as_hex(cbw.cb[1:]))

            # NOTE(review): the full response built here is 36 bytes, which
            # would normally mean an additional length of 0x1F rather than
            # 0x14 -- value kept as-is, confirm against the SCSI INQUIRY
            # data format before changing.
            response = bytes([
                0x00,       # 00 for Direct, 1F for "no floppy"
                0x00,       # make 0x80 for removable media, 0x00 for fixed
                0x00,       # Version
                0x01,       # Response Data Format
                0x14,       # Additional length.
                0x00, 0x00, 0x00
            ])

            response += b'GoodFET '         # vendor id (8 bytes)
            response += b'GoodFET '         # product id, first 8 bytes
            response += b'        '         # product id, last 8 bytes
            response += b'0.01'             # product revision (4 bytes)

        elif opcode == 0x1e:    # Prevent/Allow Removal: feign success
            if self.verbose > 0:
                print(self.name, "got SCSI Prevent/Allow Removal")

        #elif opcode == 0x1a or opcode == 0x5a:      # Mode Sense (6 or 10)
            # TODO

        elif opcode == 0x23:    # Read Format Capacity
            if self.verbose > 0:
                print(self.name, "got SCSI Read Format Capacity")

            # Hard-coded answer, not derived from the disk image.
            response = bytes([
                0x00, 0x00, 0x00, 0x08,     # capacity list length
                0x00, 0x00, 0x10, 0x00,     # number of sectors (0x1000 * 512 B = 2 MiB)
                0x10, 0x00,                 # descriptor code, then block
                0x02, 0x00,                 # length 0x000200: 512-byte sectors
            ])

        elif opcode == 0x25:    # Read Capacity
            if self.verbose > 0:
                print(self.name, "got SCSI Read Capacity, data",
                        bytes_as_hex(cbw.cb[1:]))

            # get_sector_count() already returns (number of sectors - 1)
            lastlba = self.disk_image.get_sector_count()

            response = bytes([
                (lastlba >> 24) & 0xff,
                (lastlba >> 16) & 0xff,
                (lastlba >>  8) & 0xff,
                (lastlba      ) & 0xff,
                0x00, 0x00, 0x02, 0x00,     # 512-byte blocks
            ])

        elif opcode == 0x28:    # Read (10)
            base_lba = cbw.cb[2] << 24 \
                     | cbw.cb[3] << 16 \
                     | cbw.cb[4] << 8 \
                     | cbw.cb[5]

            num_blocks = cbw.cb[7] << 8 \
                       | cbw.cb[8]

            if self.verbose > 0:
                print(self.name, "got SCSI Read (10), lba", base_lba, "+",
                        num_blocks, "block(s)")

            # Note that here we send the data directly rather than putting
            # something in 'response' and letting the end of the switch send
            for block_num in range(num_blocks):
                block = self.disk_image.get_sector_data(base_lba + block_num)
                self.configuration.device.maxusb_app.send_on_endpoint(3, block)

        elif opcode == 0x2a:    # Write (10)
            if self.verbose > 0:
                print(self.name, "got SCSI Write (10), data",
                        bytes_as_hex(cbw.cb[1:]))

            # Same command layout as Read (10): the LBA lives in bytes 2..5
            # (byte 1 holds flags), the block count in bytes 7..8.
            base_lba = cbw.cb[2] << 24 \
                     | cbw.cb[3] << 16 \
                     | cbw.cb[4] <<  8 \
                     | cbw.cb[5]

            num_blocks = cbw.cb[7] << 8 \
                       | cbw.cb[8]

            if self.verbose > 0:
                print(self.name, "got SCSI Write (10), lba", base_lba, "+",
                        num_blocks, "block(s)")

            if num_blocks > 0:
                # save for later
                self.write_cbw = cbw
                self.write_base_lba = base_lba
                self.write_length = num_blocks * self.disk_image.block_size
                self.write_data = b''
                self.is_write_in_progress = True

                # because we need to snarf up the data from wire before we
                # reply with the CSW
                return

            # A zero-block write has no data phase: fall through and send
            # the CSW now instead of mistaking the next CBW for payload.

        elif opcode == 0x35:    # Synchronize Cache (10): blindly OK
            if self.verbose > 0:
                print(self.name, "got Synchronize Cache (10)")

        else:
            print(self.name, "received unsupported SCSI opcode 0x%x" % opcode)
            status = 0x02   # command failed
            if cbw.data_transfer_length > 0:
                response = bytes(cbw.data_transfer_length)

        if response:
            if self.verbose > 2:
                print(self.name, "responding with", len(response), "bytes:",
                        bytes_as_hex(response))

            self.configuration.device.maxusb_app.send_on_endpoint(3, response)

        self._send_csw(cbw, status)

    def _handle_write_data(self, data):
        """Accumulate Write (10) payload; commit and acknowledge when complete."""
        if self.verbose > 0:
            print(self.name, "got", len(data), "bytes of SCSI write data")

        self.write_data += data

        if len(self.write_data) < self.write_length:
            # more yet to read, don't send the CSW
            return

        # Store the payload one sector at a time so every block of a
        # multi-block write reaches the image, whatever put_sector_data()
        # does with oversized buffers.
        block_size = self.disk_image.block_size
        payload = self.write_data[:self.write_length]
        for offset in range(0, self.write_length, block_size):
            self.disk_image.put_sector_data(
                    self.write_base_lba + offset // block_size,
                    payload[offset:offset + block_size])

        cbw = self.write_cbw

        self.is_write_in_progress = False
        self.write_data = b''

        self._send_csw(cbw, 0)

    def _send_csw(self, cbw, status):
        """Send the Command Status Wrapper for cbw on IN endpoint 3."""
        csw = bytes([
            ord('U'), ord('S'), ord('B'), ord('S'),
            cbw.tag[0], cbw.tag[1], cbw.tag[2], cbw.tag[3],
            0x00, 0x00, 0x00, 0x00,     # data residue: always reported as 0
            status
        ])

        if self.verbose > 3:
            print(self.name, "responding with status =", status)

        self.configuration.device.maxusb_app.send_on_endpoint(3, csw)
253
254
class DiskImage:
    """Sector-addressed view of a disk image file, backed by mmap."""

    def __init__(self, filename, block_size):
        """Open filename read/write and map the whole file.

        filename   -- path of an existing image file
        block_size -- sector size in bytes
        """
        self.filename = filename
        self.block_size = block_size

        statinfo = os.stat(self.filename)
        self.size = statinfo.st_size

        self.file = open(self.filename, 'r+b')
        self.image = mmap(self.file.fileno(), 0)

    def close(self):
        """Flush and unmap the image, then close the underlying file."""
        try:
            self.image.flush()
            self.image.close()
        finally:
            # The file object is separate from the mapping and must be
            # closed too, otherwise its descriptor leaks.
            self.file.close()

    def get_sector_count(self):
        """Return the index of the last whole sector (sector count - 1).

        Despite the name this is not a count; the Read Capacity handler
        in this file uses the value directly as the last LBA.
        """
        # integer division: no float rounding for large images
        return self.size // self.block_size - 1

    def get_sector_data(self, address):
        """Return the block_size bytes of sector number address."""
        block_start = address * self.block_size
        block_end   = block_start + self.block_size     # slices are NON-inclusive

        return self.image[block_start:block_end]

    def put_sector_data(self, address, data):
        """Write data into the image starting at sector number address.

        All of data is stored contiguously, so a buffer holding several
        sectors updates several consecutive sectors. Raises IndexError if
        the write would not fit inside the image.
        """
        block_start = address * self.block_size
        block_end   = block_start + len(data)           # slices are NON-inclusive

        if block_start < 0 or block_end > len(self.image):
            raise IndexError("write of %d bytes at sector %d is outside the image"
                    % (len(data), address))

        self.image[block_start:block_end] = data
        self.image.flush()
285
286
class CommandBlockWrapper:
    """Parsed fields of a bulk-only transport Command Block Wrapper."""

    def __init__(self, bytestring):
        """Split bytestring into the wrapper's fields.

        Raises IndexError if bytestring is shorter than 15 bytes.
        """
        raw = bytestring

        self.signature = raw[:4]
        self.tag = raw[4:8]
        # 32-bit little-endian length of the data phase
        self.data_transfer_length = int.from_bytes(raw[8:12], 'little')

        flag_byte, lun_byte, length_byte = raw[12], raw[13], raw[14]
        self.flags = int(flag_byte)
        self.lun = int(lun_byte & 0x0f)
        self.cb_length = int(length_byte & 0x1f)

        # Everything after the header is kept; cb_length is not applied.
        self.cb = raw[15:]

    def __str__(self):
        """Return a multi-line dump of every field, one per line."""
        rows = [
            "sig: " + bytes_as_hex(self.signature),
            "tag: " + bytes_as_hex(self.tag),
            "data transfer len: " + str(self.data_transfer_length),
            "flags: " + str(self.flags),
            "lun: " + str(self.lun),
            "command block len: " + str(self.cb_length),
            "command block: " + bytes_as_hex(self.cb),
        ]

        return "\n".join(rows) + "\n"
311
312
class USBMassStorageDevice(USBDevice):
    """USB device exposing one mass storage interface backed by a disk image."""
    name = "USB mass storage device"

    def __init__(self, maxusb_app, disk_image_filename, verbose=0):
        """Open the image and build the device with its single configuration.

        maxusb_app          -- application object handed on to USBDevice
        disk_image_filename -- path of the image file, opened with
                               512-byte sectors
        verbose             -- verbosity level for the interface and device
        """
        image = DiskImage(disk_image_filename, 512)
        self.disk_image = image

        storage_interface = USBMassStorageInterface(image, verbose=verbose)

        # index 1, string descriptor, list of interfaces
        configuration = USBConfiguration(
                1, "Maxim umass config", [ storage_interface ])

        USBDevice.__init__(
                self,
                maxusb_app,
                0, 0, 0,                # device class, subclass, protocol release number
                64,                     # max packet size for endpoint 0
                0x8107, 0x5051,         # vendor id: Sandisk; product id: SDCZ2 Cruzer Mini Flash Drive (thin)
                0x0003,                 # device revision
                "Maxim", "MAX3420E Enum Code", "S/N3420E",  # manufacturer, product, serial number strings
                [ configuration ],
                verbose=verbose
        )

    def disconnect(self):
        """Close the disk image, then run the normal USBDevice disconnect."""
        self.disk_image.close()
        USBDevice.disconnect(self)
347