massive commit to get umass working, only significant known bug is that changes don...
[goodfet] / client / USBDevice.py
index e25fea8..49ba0a4 100644 (file)
@@ -3,6 +3,7 @@
 # Contains class definitions for USBDevice and USBDeviceRequest.
 
 from USB import *
+from USBClass import *
 
 class USBDevice:
     name = "generic device"
@@ -42,9 +43,7 @@ class USBDevice:
         for c in self.configurations:
             csi = self.get_string_id(c.configuration_string)
             c.set_configuration_string_index(csi)
-
-            for i in c.interfaces:
-                i.device = self
+            c.set_device(self)
 
         self.state = USB.state_detached
         self.ready = False
@@ -128,38 +127,57 @@ class USBDevice:
             print(self.name, "received request", req)
 
         # figure out the intended recipient
-        if req.get_recipient() == USB.request_recipient_device:
+        recipient_type = req.get_recipient()
+        recipient = None
+        index = req.get_index()
+        if recipient_type == USB.request_recipient_device:
             recipient = self
-        elif req.get_recipient() == USB.request_recipient_interface:
-            recipient = self.configuration.interfaces[req.index]
-        elif req.get_recipient() == USB.request_recipient_endpoint:
-            try:
-                recipient = self.endpoints[req.index]
-            except KeyError:
-                self.maxusb_app.stall_ep0()
-                return
+        elif recipient_type == USB.request_recipient_interface:
+            if index < len(self.configuration.interfaces):
+                recipient = self.configuration.interfaces[index]
+        elif recipient_type == USB.request_recipient_endpoint:
+            recipient = self.endpoints.get(index, None)
+
+        if not recipient:
+            print(self.name, "invalid recipient, stalling")
+            self.maxusb_app.stall_ep0()
+            return
 
         # and then the type
-        if req.get_type() == USB.request_type_standard:
-            handler = recipient.request_handlers[req.request]
-            handler(req)
-        elif req.get_type() == USB.request_type_class:
-            # HACK: evidently, FreeBSD doesn't pay attention to the device
-            # until it sends a GET_STATUS(class) message
-            self.ready = True
+        req_type = req.get_type()
+        handler_entity = None
+        if req_type == USB.request_type_standard:
+            handler_entity = recipient
+        elif req_type == USB.request_type_class:
+            handler_entity = recipient.device_class
+        elif req_type == USB.request_type_vendor:
+            handler_entity = recipient.get_device_vendor()
+
+        if not handler_entity:
+            print(self.name, "invalid handler entity, stalling")
             self.maxusb_app.stall_ep0()
-        elif req.get_type() == USB.request_type_vendor:
+            return
+
+        handler = handler_entity.request_handlers.get(req.request, None)
+
+        if not handler:
+            print(self.name, "invalid handler, stalling")
             self.maxusb_app.stall_ep0()
+            return
+
+        handler(req)
 
     def handle_data_available(self, ep_num, data):
         if self.state == USB.state_configured and ep_num in self.endpoints:
             endpoint = self.endpoints[ep_num]
-            endpoint.handler(data)
+            if callable(endpoint.handler):
+                endpoint.handler(data)
 
     def handle_buffer_available(self, ep_num):
         if self.state == USB.state_configured and ep_num in self.endpoints:
             endpoint = self.endpoints[ep_num]
-            endpoint.handler()
+            if callable(endpoint.handler):
+                endpoint.handler()
     
     # standard request handlers
     #####################################################
@@ -327,3 +345,12 @@ class USBDeviceRequest:
     def get_recipient(self):
         return self.request_type & 0x1f
 
+    # meaning of bits in wIndex changes whether we're talking about an
+    # interface or an endpoint (see USB 2.0 spec section 9.3.4)
+    def get_index(self):
+        rec = self.get_recipient()
+        if rec == 1:                # interface
+            return self.index
+        elif rec == 2:              # endpoint
+            return self.index & 0x0f
+