X-Git-Url: http://git.rot13.org/?p=goodfet;a=blobdiff_plain;f=client%2FUSBDevice.py;h=4a2ba1316eae0002b618771799e72a4011a8d0b2;hp=84e8b1ff3a3e84f958af2206308b834843bcfdf3;hb=a2f6b68ebded8f6213a19adf22306e892aa9fff6;hpb=96ebf28fac2e03cfee954db85282eac7b8ec8015;ds=sidebyside diff --git a/client/USBDevice.py b/client/USBDevice.py index 84e8b1f..4a2ba13 100644 --- a/client/USBDevice.py +++ b/client/USBDevice.py @@ -3,6 +3,7 @@ # Contains class definitions for USBDevice and USBDeviceRequest. from USB import * +from USBClass import * class USBDevice: name = "generic device" @@ -42,9 +43,7 @@ class USBDevice: for c in self.configurations: csi = self.get_string_id(c.configuration_string) c.set_configuration_string_index(csi) - - for i in c.interfaces: - i.device = self + c.set_device(self) self.state = USB.state_detached self.ready = False @@ -128,34 +127,57 @@ class USBDevice: print(self.name, "received request", req) # figure out the intended recipient - if req.get_recipient() == USB.request_recipient_device: + recipient_type = req.get_recipient() + recipient = None + index = req.get_index() + if recipient_type == USB.request_recipient_device: recipient = self - elif req.get_recipient() == USB.request_recipient_interface: - recipient = self.configuration.interfaces[req.index] - elif req.get_recipient() == USB.request_recipient_endpoint: - recipient = self.configuration.endpoints[req.index] + elif recipient_type == USB.request_recipient_interface: + if index < len(self.configuration.interfaces): + recipient = self.configuration.interfaces[index] + elif recipient_type == USB.request_recipient_endpoint: + recipient = self.endpoints.get(index, None) + + if not recipient: + print(self.name, "invalid recipient, stalling") + self.maxusb_app.stall_ep0() + return # and then the type - if req.get_type() == USB.request_type_standard: - handler = recipient.request_handlers[req.request] - handler(req) - elif req.get_type() == USB.request_type_class: - # HACK: evidently, FreeBSD doesn't pay attention to the device - # until it sends a GET_STATUS(class) message - self.ready = True + req_type = req.get_type() + handler_entity = None + if req_type == USB.request_type_standard: + handler_entity = recipient + elif req_type == USB.request_type_class: + handler_entity = recipient.device_class + elif req_type == USB.request_type_vendor: + handler_entity = recipient.get_device_vendor() + + if not handler_entity: + print(self.name, "invalid handler entity, stalling") self.maxusb_app.stall_ep0() - elif req.get_type() == USB.request_type_vendor: + return + + handler = handler_entity.request_handlers.get(req.request, None) + + if not handler: + print(self.name, "invalid handler, stalling") self.maxusb_app.stall_ep0() + return + + handler(req) def handle_data_available(self, ep_num, data): - if self.ready and ep_num in self.endpoints: + if self.state == USB.state_configured and ep_num in self.endpoints: endpoint = self.endpoints[ep_num] - endpoint.handler(data) + if callable(endpoint.handler): + endpoint.handler(data) def handle_buffer_available(self, ep_num): - if self.ready and ep_num in self.endpoints: + if self.state == USB.state_configured and ep_num in self.endpoints: endpoint = self.endpoints[ep_num] - endpoint.handler() + if callable(endpoint.handler): + endpoint.handler() # standard request handlers ##################################################### @@ -201,8 +223,7 @@ class USBDevice: + "language 0x%04x, length %d") \ % (dtype, dindex, lang, n)) - # TODO: handle KeyError - response = self.descriptors[dtype] + response = self.descriptors.get(dtype, None) if callable(response): response = response(dindex) @@ -214,6 +235,8 @@ class USBDevice: if self.verbose > 5: print(self.name, "sent", n, "bytes in response") + else: + self.maxusb_app.stall_ep0() def handle_get_configuration_descriptor_request(self, num): return self.configurations[num].get_descriptor() @@ -230,6 +253,11 @@ class USBDevice: else: # string descriptors start at 1 s = self.strings[num-1].encode('utf-16') + + # Linux doesn't like the leading 2-byte Byte Order Mark (BOM); + # FreeBSD is okay without it + s = s[2:] + d = bytearray([ len(s) + 2, # length of descriptor in bytes 3 # descriptor type 3 == string @@ -318,3 +346,12 @@ class USBDeviceRequest: def get_recipient(self): return self.request_type & 0x1f + # meaning of bits in wIndex changes whether we're talking about an + # interface or an endpoint (see USB 2.0 spec section 9.3.4) + def get_index(self): + rec = self.get_recipient() + if rec == 1: # interface + return self.index + elif rec == 2: # endpoint + return self.index & 0x0f +