turn ftdi driver into echo server
[goodfet] / client / USBMassStorage.py
1 # USBMassStorage.py 
2 #
3 # Contains class definitions to implement a USB mass storage device.
4
5 from mmap import mmap
6 import os
7
8 from USB import *
9 from USBDevice import *
10 from USBConfiguration import *
11 from USBInterface import *
12 from USBEndpoint import *
13 from USBClass import *
14
15 from util import *
16
class USBMassStorageClass(USBClass):
    """Class-specific control requests for the mass storage interface.

    Handles the Bulk-Only Mass Storage Reset (bRequest 0xFF) and
    Get Max LUN (bRequest 0xFE) requests; replies go out on endpoint 0.
    """

    name = "USB mass storage class"

    def setup_request_handlers(self):
        """Populate self.request_handlers, keyed by bRequest code."""
        handlers = dict()
        handlers[0xFF] = self.handle_bulk_only_mass_storage_reset_request
        handlers[0xFE] = self.handle_get_max_lun_request
        self.request_handlers = handlers

    def handle_bulk_only_mass_storage_reset_request(self, req):
        """Acknowledge a reset request with a zero-length reply on EP0."""
        app = self.interface.configuration.device.maxusb_app
        app.send_on_endpoint(0, b'')

    def handle_get_max_lun_request(self, req):
        """Reply with max LUN 0, i.e. a single logical unit."""
        app = self.interface.configuration.device.maxusb_app
        app.send_on_endpoint(0, b'\x00')
31
32
class USBMassStorageInterface(USBInterface):
    """Mass storage interface that serves a DiskImage to the USB host.

    SCSI commands wrapped in Command Block Wrappers (CBWs), and WRITE(10)
    payload, arrive on bulk OUT endpoint 1 and are processed by
    handle_data_available().  Response data and Command Status Wrappers
    (CSWs) are sent back on bulk IN endpoint 3.
    """

    name = "USB mass storage interface"

    # bCSWStatus values from the USB Mass Storage Bulk-Only Transport spec.
    CSW_STATUS_PASSED = 0x00
    CSW_STATUS_FAILED = 0x01

    # A well-formed CBW is exactly 31 bytes long.
    CBW_LENGTH = 31

    # Request Sense reply used when no error is pending.
    DEFAULT_SENSE_DATA = b'\x70\x00\xFF\x00\x00\x00\x00\x0A\x00\x00\x00\x00\xFF\xFF\x00\x00\x00\x00\x00\x00\x00\x00\x00'

    # Fixed-format sense data reported after an unsupported command:
    # sense key ILLEGAL REQUEST (0x05), ASC/ASCQ 0x20/0x00
    # ("invalid command operation code").
    INVALID_OPCODE_SENSE_DATA = bytes([
        0x70, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x0A,
        0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
        0x00, 0x00,
    ])

    def __init__(self, disk_image, verbose=0):
        """Create the interface.

        disk_image: DiskImage backing the emulated drive.
        verbose:    debug verbosity; higher values print more traffic.
        """
        self.disk_image = disk_image
        descriptors = { }

        endpoints = [
            USBEndpoint(
                1,          # endpoint number
                USBEndpoint.direction_out,
                USBEndpoint.transfer_type_bulk,
                USBEndpoint.sync_type_none,
                USBEndpoint.usage_type_data,
                16384,      # max packet size
                0,          # polling interval, see USB 2.0 spec Table 9-13
                self.handle_data_available    # handler function
            ),
            USBEndpoint(
                3,          # endpoint number
                USBEndpoint.direction_in,
                USBEndpoint.transfer_type_bulk,
                USBEndpoint.sync_type_none,
                USBEndpoint.usage_type_data,
                16384,      # max packet size
                0,          # polling interval, see USB 2.0 spec Table 9-13
                None        # handler function
            )
        ]

        # TODO: un-hardcode string index (last arg before "verbose")
        USBInterface.__init__(
                self,
                0,          # interface number
                0,          # alternate setting
                8,          # interface class: Mass Storage
                6,          # subclass: SCSI transparent command set
                0x50,       # protocol: bulk-only (BBB) transport
                0,          # string index
                verbose,
                endpoints,
                descriptors
        )

        self.device_class = USBMassStorageClass()
        self.device_class.set_interface(self)

        # State of a WRITE(10) whose payload is still arriving.
        self.is_write_in_progress = False
        self.write_cbw = None
        self.write_base_lba = 0
        self.write_length = 0
        self.write_data = b''

        # Sense data to return on the next Request Sense, or None.
        self.sense_data = None

    def handle_data_available(self, data):
        """Handle one transfer received on bulk OUT endpoint 1.

        While a WRITE(10) is in progress, *data* is payload for that write
        and the CSW is deferred until all of it has arrived.  Otherwise
        *data* is parsed as a CBW and the SCSI command it carries is
        executed; any response data and then the CSW are sent on endpoint 3.
        """
        if self.verbose > 0:
            print(self.name, "handling", len(data), "bytes of SCSI data")

        # Write payload is not a CBW, so it must be handled before parsing.
        if self.is_write_in_progress:
            cbw = self._accept_write_data(data)
            if cbw is not None:
                self._send_csw(cbw, self.CSW_STATUS_PASSED)
            return

        if len(data) < self.CBW_LENGTH:
            print(self.name, "ignoring", len(data),
                    "byte transfer: too short to be a CBW")
            return

        cbw = CommandBlockWrapper(data)
        opcode = cbw.cb[0]

        status = self.CSW_STATUS_PASSED
        response = None         # data-in payload for the host, if any

        # Sense data describes only the most recent command.
        if opcode != 0x03:
            self.sense_data = None

        if opcode == 0x00:      # Test Unit Ready: just return OK status
            if self.verbose > 0:
                print(self.name, "got SCSI Test Unit Ready")

        elif opcode == 0x03:    # Request Sense
            if self.verbose > 0:
                print(self.name, "got SCSI Request Sense, data",
                        bytes_as_hex(cbw.cb[1:]))

            if self.sense_data is not None:
                # report (and clear) the error left by the previous command
                response = self.sense_data
                self.sense_data = None
            else:
                response = self.DEFAULT_SENSE_DATA

        elif opcode == 0x12:    # Inquiry
            if self.verbose > 0:
                print(self.name, "got SCSI Inquiry, data",
                        bytes_as_hex(cbw.cb[1:]))

            response = self._inquiry_data()

        elif opcode == 0x1a or opcode == 0x5a:    # Mode Sense (6 or 10)
            page = cbw.cb[2] & 0x3f

            if self.verbose > 0:
                print(self.name, "got SCSI Mode Sense, page code 0x%02x" % page)

            # NOTE(review): the same reply is used for MODE SENSE(10), whose
            # header layout differs from MODE SENSE(6) -- confirm hosts accept it.
            response = b'\x07\x00\x00\x00\x00\x00\x00\x1c'
            if page != 0x3f:
                print(self.name, "unkonwn page, returning empty page")
                response = b'\x07\x00\x00\x00\x00\x00\x00\x00'

        elif opcode == 0x1e:    # Prevent/Allow Removal: feign success
            if self.verbose > 0:
                print(self.name, "got SCSI Prevent/Allow Removal")

        elif opcode == 0x23:    # Read Format Capacity
            if self.verbose > 0:
                print(self.name, "got SCSI Read Format Capacity")

            response = self._format_capacity_data()

        elif opcode == 0x25:    # Read Capacity
            if self.verbose > 0:
                print(self.name, "got SCSI Read Capacity, data",
                        bytes_as_hex(cbw.cb[1:]))

            response = self._capacity_data()

        elif opcode == 0x28:    # Read (10)
            base_lba, num_blocks = self._parse_rw10(cbw.cb)

            if self.verbose > 0:
                print(self.name, "got SCSI Read (10), lba", base_lba, "+",
                        num_blocks, "block(s)")

            # Sector data is sent directly, one block per call, rather than
            # through 'response' below.
            app = self.configuration.device.maxusb_app
            for block_num in range(num_blocks):
                sector = self.disk_image.get_sector_data(base_lba + block_num)
                app.send_on_endpoint(3, sector)

        elif opcode == 0x2a:    # Write (10)
            if self.verbose > 0:
                print(self.name, "got SCSI Write (10), data",
                        bytes_as_hex(cbw.cb[1:]))

            base_lba, num_blocks = self._parse_rw10(cbw.cb)

            if self.verbose > 0:
                print(self.name, "got SCSI Write (10), lba", base_lba, "+",
                        num_blocks, "block(s)")

            if num_blocks > 0:
                # The payload follows in later transfers; the CSW is sent
                # once all of it has been received.
                self.write_cbw = cbw
                self.write_base_lba = base_lba
                self.write_length = num_blocks * self.disk_image.block_size
                self.write_data = b''
                self.is_write_in_progress = True
                return
            # zero-length write: no payload follows, complete immediately

        elif opcode == 0x35:    # Synchronize Cache (10): blindly OK
            if self.verbose > 0:
                print(self.name, "got Synchronize Cache (10)")

        else:
            print(self.name, "received unsupported SCSI opcode 0x%x" % opcode)
            status = self.CSW_STATUS_FAILED
            self.sense_data = self.INVALID_OPCODE_SENSE_DATA
            if cbw.data_transfer_length > 0:
                response = bytes(cbw.data_transfer_length)

        if response:
            # never send more than the host asked for in dCBWDataTransferLength
            response = response[:cbw.data_transfer_length]

        if response:
            if self.verbose > 2:
                print(self.name, "responding with", len(response), "bytes:",
                        bytes_as_hex(response))

            self.configuration.device.maxusb_app.send_on_endpoint(3, response)

        self._send_csw(cbw, status)

    @staticmethod
    def _parse_rw10(cb):
        """Return (lba, num_blocks) from a READ(10)/WRITE(10) command block.

        Both commands carry a big-endian LBA in bytes 2-5 and a big-endian
        transfer length (in blocks) in bytes 7-8.
        """
        base_lba = cb[2] << 24 | cb[3] << 16 | cb[4] << 8 | cb[5]
        num_blocks = cb[7] << 8 | cb[8]
        return base_lba, num_blocks

    @staticmethod
    def _be32(value):
        """Encode the low 32 bits of *value* as 4 big-endian bytes."""
        return bytes([(value >> shift) & 0xff for shift in (24, 16, 8, 0)])

    def _inquiry_data(self):
        """Return standard INQUIRY data with a correct additional length."""
        ident = (b'GoodFET '              # vendor identification (8 bytes)
                 + b'GoodFET         '    # product identification (16 bytes)
                 + b'0.01')               # product revision level (4 bytes)
        header = bytes([
            0x00,       # 00 for Direct, 1F for "no floppy"
            0x00,       # make 0x80 for removable media, 0x00 for fixed
            0x00,       # Version
            0x01,       # Response Data Format
            3 + len(ident),     # additional length: bytes after byte 4
            0x00, 0x00, 0x00,
        ])
        return header + ident

    def _capacity_data(self):
        """Return READ CAPACITY(10) data: last LBA, then block length."""
        # 0xFFFFFFFF tells the host the disk is too big for READ CAPACITY(10)
        last_lba = min(self.disk_image.get_sector_count(), 0xFFFFFFFF)
        return self._be32(last_lba) + self._be32(self.disk_image.block_size)

    def _format_capacity_data(self):
        """Return READ FORMAT CAPACITIES data with one capacity descriptor."""
        num_blocks = self.disk_image.get_sector_count() + 1
        return (bytes([0x00, 0x00, 0x00, 0x08])    # capacity list length
                + self._be32(num_blocks)            # number of blocks
                + bytes([0x02])                     # descriptor code 10b: formatted media
                + self._be32(self.disk_image.block_size)[1:])   # 3-byte block length

    def _accept_write_data(self, data):
        """Buffer WRITE(10) payload and commit it once complete.

        Returns the pending write's CBW after all expected bytes have been
        written to the disk image, or None if more payload is expected.
        """
        if self.verbose > 0:
            print(self.name, "got", len(data), "bytes of SCSI write data")

        self.write_data += data
        if len(self.write_data) < self.write_length:
            return None         # more yet to read, don't send the CSW

        # put_sector_data stores one sector per call, so commit block by block.
        block_size = self.disk_image.block_size
        payload = self.write_data[:self.write_length]
        for offset in range(0, len(payload), block_size):
            self.disk_image.put_sector_data(
                    self.write_base_lba + offset // block_size,
                    payload[offset:offset + block_size])

        cbw = self.write_cbw
        self.is_write_in_progress = False
        self.write_cbw = None
        self.write_data = b''
        return cbw

    def _send_csw(self, cbw, status):
        """Send the Command Status Wrapper for *cbw* on endpoint 3.

        The tag is echoed from the CBW; the data residue is reported as 0.
        """
        csw = b'USBS' + bytes(cbw.tag) + b'\x00\x00\x00\x00' + bytes([status])

        if self.verbose > 3:
            print(self.name, "responding with status =", status)

        self.configuration.device.maxusb_app.send_on_endpoint(3, csw)
264
265
class DiskImage:
    """Sector-addressed read/write access to a disk image file via mmap.

    The whole file is memory-mapped. Sector N occupies bytes
    [N * block_size, (N + 1) * block_size) of the file.
    """

    def __init__(self, filename, block_size):
        """Open and map *filename*.

        filename:   path to an existing, non-empty image file (opened r+b).
        block_size: sector size in bytes.
        """
        self.filename = filename
        self.block_size = block_size

        statinfo = os.stat(self.filename)
        self.size = statinfo.st_size

        self.file = open(self.filename, 'r+b')
        try:
            self.image = mmap(self.file.fileno(), 0)
        except BaseException:
            # e.g. mmap refuses empty files: don't leak the descriptor
            self.file.close()
            raise

    def close(self):
        """Flush pending writes, unmap the image and close the file."""
        self.image.flush()
        self.image.close()
        self.file.close()

    def get_sector_count(self):
        """Return the index of the last whole sector (sector count - 1).

        Despite its name this is the "last LBA" value; callers add 1 to get
        the number of sectors.
        """
        return self.size // self.block_size - 1

    def get_sector_data(self, address):
        """Return the bytes of sector *address* (shorter past end of image)."""
        block_start = address * self.block_size
        block_end   = (address + 1) * self.block_size   # slices are NON-inclusive

        return self.image[block_start:block_end]

    def put_sector_data(self, address, data):
        """Overwrite sector *address* with the first block_size bytes of *data*.

        Only one sector is written per call. The mmap slice assignment
        requires *data* to supply at least a full sector.
        """
        block_start = address * self.block_size
        block_end   = (address + 1) * self.block_size   # slices are NON-inclusive

        self.image[block_start:block_end] = data[:self.block_size]
        self.image.flush()
296
297
class CommandBlockWrapper:
    """Decoded USB mass storage Command Block Wrapper (CBW).

    Field layout of the raw bytes: signature (0-3), tag (4-7),
    little-endian data transfer length (8-11), flags (12), LUN (low
    nibble of 13), command block length (low 5 bits of 14), and the
    command block starting at byte 15.
    """

    def __init__(self, bytestring):
        raw = bytestring
        self.signature = raw[0:4]
        self.tag = raw[4:8]
        # bytes 8..11, least significant first
        self.data_transfer_length = sum(raw[8 + i] << (8 * i) for i in range(4))
        self.flags = int(raw[12])
        self.lun = int(raw[13] & 0x0f)
        self.cb_length = int(raw[14] & 0x1f)
        # NOTE: keeps everything from byte 15 on, not just cb_length bytes.
        self.cb = raw[15:]

    def __str__(self):
        rows = [
            ("sig: ", bytes_as_hex(self.signature)),
            ("tag: ", bytes_as_hex(self.tag)),
            ("data transfer len: ", str(self.data_transfer_length)),
            ("flags: ", str(self.flags)),
            ("lun: ", str(self.lun)),
            ("command block len: ", str(self.cb_length)),
            ("command block: ", bytes_as_hex(self.cb)),
        ]
        return "".join(label + value + "\n" for label, value in rows)
322
323
class USBMassStorageDevice(USBDevice):
    """USB device that exposes a disk image file as mass storage."""

    name = "USB mass storage device"

    def __init__(self, maxusb_app, disk_image_filename, verbose=0):
        """Open the image (512-byte sectors) and build the device tree.

        maxusb_app:          application object passed to USBDevice.
        disk_image_filename: path of the image file to serve.
        verbose:             debug verbosity forwarded to children.
        """
        self.disk_image = DiskImage(disk_image_filename, 512)

        storage_interface = USBMassStorageInterface(self.disk_image,
                                                    verbose=verbose)
        configurations = [
            USBConfiguration(1, "Maxim umass config", [ storage_interface ]),
        ]

        # NOTE(review): the "Sandisk" IDs below look byte-swapped relative to
        # SanDisk's vendor ID 0x0781 / Cruzer Mini product ID 0x5150 --
        # confirm how USBDevice encodes them before relying on these values.
        USBDevice.__init__(
                self,
                maxusb_app,
                0, 0, 0,                # device class, subclass, protocol
                64,                     # max packet size for endpoint 0
                0x8107,                 # vendor id
                0x5051,                 # product id
                0x0003,                 # device revision
                "Maxim",                # manufacturer string
                "MAX3420E Enum Code",   # product string
                "S/N3420E",             # serial number string
                configurations,
                verbose=verbose
        )

    def disconnect(self):
        """Release the disk image, then let USBDevice finish disconnecting."""
        image = self.disk_image
        image.close()
        USBDevice.disconnect(self)
358