stall when a descriptor is requested for which we have no data
[goodfet] / client / USBDevice.py
index 84e8b1f..4a2ba13 100644 (file)
@@ -3,6 +3,7 @@
 # Contains class definitions for USBDevice and USBDeviceRequest.
 
 from USB import *
 # Contains class definitions for USBDevice and USBDeviceRequest.
 
 from USB import *
+from USBClass import *
 
 class USBDevice:
     name = "generic device"
 
 class USBDevice:
     name = "generic device"
@@ -42,9 +43,7 @@ class USBDevice:
         for c in self.configurations:
             csi = self.get_string_id(c.configuration_string)
             c.set_configuration_string_index(csi)
         for c in self.configurations:
             csi = self.get_string_id(c.configuration_string)
             c.set_configuration_string_index(csi)
-
-            for i in c.interfaces:
-                i.device = self
+            c.set_device(self)
 
         self.state = USB.state_detached
         self.ready = False
 
         self.state = USB.state_detached
         self.ready = False
@@ -128,34 +127,57 @@ class USBDevice:
             print(self.name, "received request", req)
 
         # figure out the intended recipient
             print(self.name, "received request", req)
 
         # figure out the intended recipient
-        if req.get_recipient() == USB.request_recipient_device:
+        recipient_type = req.get_recipient()
+        recipient = None
+        index = req.get_index()
+        if recipient_type == USB.request_recipient_device:
             recipient = self
             recipient = self
-        elif req.get_recipient() == USB.request_recipient_interface:
-            recipient = self.configuration.interfaces[req.index]
-        elif req.get_recipient() == USB.request_recipient_endpoint:
-            recipient = self.configuration.endpoints[req.index]
+        elif recipient_type == USB.request_recipient_interface:
+            if index < len(self.configuration.interfaces):
+                recipient = self.configuration.interfaces[index]
+        elif recipient_type == USB.request_recipient_endpoint:
+            recipient = self.endpoints.get(index, None)
+
+        if not recipient:
+            print(self.name, "invalid recipient, stalling")
+            self.maxusb_app.stall_ep0()
+            return
 
         # and then the type
 
         # and then the type
-        if req.get_type() == USB.request_type_standard:
-            handler = recipient.request_handlers[req.request]
-            handler(req)
-        elif req.get_type() == USB.request_type_class:
-            # HACK: evidently, FreeBSD doesn't pay attention to the device
-            # until it sends a GET_STATUS(class) message
-            self.ready = True
+        req_type = req.get_type()
+        handler_entity = None
+        if req_type == USB.request_type_standard:
+            handler_entity = recipient
+        elif req_type == USB.request_type_class:
+            handler_entity = recipient.device_class
+        elif req_type == USB.request_type_vendor:
+            handler_entity = recipient.get_device_vendor()
+
+        if not handler_entity:
+            print(self.name, "invalid handler entity, stalling")
             self.maxusb_app.stall_ep0()
             self.maxusb_app.stall_ep0()
-        elif req.get_type() == USB.request_type_vendor:
+            return
+
+        handler = handler_entity.request_handlers.get(req.request, None)
+
+        if not handler:
+            print(self.name, "invalid handler, stalling")
             self.maxusb_app.stall_ep0()
             self.maxusb_app.stall_ep0()
+            return
+
+        handler(req)
 
     def handle_data_available(self, ep_num, data):
 
     def handle_data_available(self, ep_num, data):
-        if self.ready and ep_num in self.endpoints:
+        if self.state == USB.state_configured and ep_num in self.endpoints:
             endpoint = self.endpoints[ep_num]
             endpoint = self.endpoints[ep_num]
-            endpoint.handler(data)
+            if callable(endpoint.handler):
+                endpoint.handler(data)
 
     def handle_buffer_available(self, ep_num):
 
     def handle_buffer_available(self, ep_num):
-        if self.ready and ep_num in self.endpoints:
+        if self.state == USB.state_configured and ep_num in self.endpoints:
             endpoint = self.endpoints[ep_num]
             endpoint = self.endpoints[ep_num]
-            endpoint.handler()
+            if callable(endpoint.handler):
+                endpoint.handler()
     
     # standard request handlers
     #####################################################
     
     # standard request handlers
     #####################################################
@@ -201,8 +223,7 @@ class USBDevice:
                     + "language 0x%04x, length %d") \
                     % (dtype, dindex, lang, n))
 
                     + "language 0x%04x, length %d") \
                     % (dtype, dindex, lang, n))
 
-        # TODO: handle KeyError
-        response = self.descriptors[dtype]
+        response = self.descriptors.get(dtype, None)
         if callable(response):
             response = response(dindex)
 
         if callable(response):
             response = response(dindex)
 
@@ -214,6 +235,8 @@ class USBDevice:
 
             if self.verbose > 5:
                 print(self.name, "sent", n, "bytes in response")
 
             if self.verbose > 5:
                 print(self.name, "sent", n, "bytes in response")
+        else:
+            self.maxusb_app.stall_ep0()
 
     def handle_get_configuration_descriptor_request(self, num):
         return self.configurations[num].get_descriptor()
 
     def handle_get_configuration_descriptor_request(self, num):
         return self.configurations[num].get_descriptor()
@@ -230,6 +253,11 @@ class USBDevice:
         else:
             # string descriptors start at 1
             s = self.strings[num-1].encode('utf-16')
         else:
             # string descriptors start at 1
             s = self.strings[num-1].encode('utf-16')
+
+            # Linux doesn't like the leading 2-byte Byte Order Mark (BOM);
+            # FreeBSD is okay without it
+            s = s[2:]
+
             d = bytearray([
                     len(s) + 2,     # length of descriptor in bytes
                     3               # descriptor type 3 == string
             d = bytearray([
                     len(s) + 2,     # length of descriptor in bytes
                     3               # descriptor type 3 == string
@@ -318,3 +346,12 @@ class USBDeviceRequest:
     def get_recipient(self):
         return self.request_type & 0x1f
 
     def get_recipient(self):
         return self.request_type & 0x1f
 
+    # meaning of bits in wIndex changes whether we're talking about an
+    # interface or an endpoint (see USB 2.0 spec section 9.3.4)
+    def get_index(self):
+        rec = self.get_recipient()
+        if rec == 1:                # interface
+            return self.index
+        elif rec == 2:              # endpoint
+            return self.index & 0x0f
+