这里我们只实现标准的base64, 补充位用=填充
编码
下面是base64字符的对照表。因为base64编码是把每6bit数据用一个8bit的字符来表示, 所以编码后的长度会比原来增长1/3; 另外2^6=64, 这也是为什么这个表会有64个索引: A-Za-z0-9+/ 一共是64个字符。编码后如果长度不是4的倍数, 需要用=填充
eg:
I编码到SQ(带填充为SQ==), AM编码到QU0(带填充为QU0=), TJM编码到VEpN(本身就是4的倍数, 带填充也是VEpN)。假设我们有一个程序, 用base64对单词进行编码, 把结果连接起来并通过网络发送。它对"I"、"AM"和"TJM"进行编码, 将结果直接拼接在一起而不加填充, 传输的数据是SQQU0VEpN, 但是解码之后得到的是I\x04\x14\xd1Q)。这是因为没有填充时, 解码器无法知道每个单词的编码在哪里结束, 只能按每4个字符一组(SQQU、0VEp、N)去解码, 于是前一个单词剩下的比特和后一个单词的比特被错误地拼接在一起, 得到了错误的结果
- 首先是对输入类型进行校验,是否是字节流
- 由于可能需要补=, 所以把不需要补和需要补的部分分开进行转换
- 对于不需要填补=字符的,每次取三个字节, 转化成四个字符码值,然后查表得到字符
- 对于需要填补=字符的(剩余1或2个字节), 先在剩余字节后补 (3 - 剩余字节数) 个0字节凑成3个字节, 转换后只保留前 (剩余字节数 + 1) 个字符, 再追加 (3 - 剩余字节数) 个=
- 对于解码类似,将四个字符转换成三个字节
import string
from typing import Union
class EncodeError(Exception):
    """Error type for failures while Base64-encoding data."""
class DecodeError(Exception):
    """Error type for malformed input given to ``Base64.decode``."""
class Base64(object):
    """Standard Base64 encoder/decoder (alphabet ``A-Za-z0-9+/``, ``=`` padding).

    Encoded output is always padded with ``=`` to a multiple of four characters,
    and :meth:`decode` only accepts input in that padded form.
    """

    # Smallest non-empty encoded length (implied by the multiple-of-4 rule).
    MIN_LENGTH: int = 4
    # Encoded text length must be a multiple of this.
    MODULO_NUMBER: int = 4
    # Index i holds the character for the 6-bit value i; 2**6 = 64 entries.
    CHARACTER_TABLE: str = string.ascii_uppercase + string.ascii_lowercase + string.digits + '+/'
    # Padding character used when the input is not a multiple of 3 bytes.
    COVERING_CHARACTER: str = '='
    # Not referenced inside this class; kept for backward compatibility.
    ZERO_CHARACTER = '0'

    @staticmethod
    def encode(s: Union[bytes, bytearray]) -> str:
        """Encode *s* as a padded standard Base64 string.

        Every 3 input bytes (24 bits) become 4 characters, one per 6-bit group.
        A trailing group of 1 or 2 bytes is zero-filled to 3 bytes. Only the
        characters that carry real input bits are kept (remaining bytes + 1),
        followed by ``3 - remaining`` ``=`` characters.

        :param s: the bytes to encode.
        :return: the Base64 text ("" for empty input).
        :raises TypeError: if *s* is not ``bytes`` or ``bytearray``.
        """
        if not isinstance(s, (bytes, bytearray)):
            raise TypeError("the input s type is error!")
        table = Base64.CHARACTER_TABLE
        left_count = len(s) % 3
        full_length = len(s) - left_count
        result = []
        # Index-based walk: O(n), unlike re-slicing the remaining list each round.
        for start in range(0, full_length, 3):
            group = int.from_bytes(s[start:start + 3], 'big')
            result.extend(table[(group >> shift) & 0x3F] for shift in (18, 12, 6, 0))
        if left_count:
            covering_count = 3 - left_count
            # Zero-fill the tail to a full 24-bit group.
            group = int.from_bytes(bytes(s[full_length:]) + bytes(covering_count), 'big')
            # n leftover bytes carry real bits only in the first n + 1 sextets.
            result.extend(table[(group >> shift) & 0x3F] for shift in (18, 12, 6)[:left_count + 1])
            result.append(covering_count * Base64.COVERING_CHARACTER)
        return ''.join(result)

    @staticmethod
    def _pack_sextets(values: list) -> int:
        """Concatenate 6-bit values (most significant first) into one integer."""
        group = 0
        for value in values:
            group = (group << 6) | value
        return group

    @staticmethod
    def __valid_base64_str(s: str) -> bool:
        """Return True if *s* is well-formed, padded Base64 text.

        Rules: the length is a multiple of MODULO_NUMBER (so any non-empty valid
        string is at least MIN_LENGTH long; "" is accepted because it is what
        ``encode(b'')`` returns), every character is from CHARACTER_TABLE, and
        ``=`` may only appear as a trailing run of at most two characters.
        """
        # Previously `len < 4 and len % 4 != 0`, which let e.g. length 5 or 9 through.
        if len(s) % Base64.MODULO_NUMBER != 0:
            return False
        body = s.rstrip(Base64.COVERING_CHARACTER)
        if len(s) - len(body) > 2:
            return False
        # '=' is not in the table, so any '=' left inside `body` is rejected here.
        allowed = set(Base64.CHARACTER_TABLE)
        return all(character in allowed for character in body)

    @staticmethod
    def decode(s: str) -> bytearray:
        """Decode padded standard Base64 text.

        Every 4 characters become 3 bytes. A final group of 2 or 3 data
        characters (followed by ``==`` or ``=``) becomes 1 or 2 bytes, and the
        leftover filler bits are discarded.

        :param s: Base64 text whose length is a multiple of 4 ("" is allowed).
        :return: the decoded data as a ``bytearray``.
        :raises TypeError: if *s* is not a ``str``.
        :raises DecodeError: if *s* fails the ``__valid_base64_str`` rules.
        """
        if not isinstance(s, str):
            raise TypeError("the input s type is error!")
        if not Base64.__valid_base64_str(s):
            raise DecodeError("input s is invalid!")
        lookup = {character: value for value, character in enumerate(Base64.CHARACTER_TABLE)}
        values = [lookup[character] for character in s.rstrip(Base64.COVERING_CHARACTER)]
        left_count = len(values) % 4
        full_length = len(values) - left_count
        result = bytearray()
        for start in range(0, full_length, 4):
            result.extend(Base64._pack_sextets(values[start:start + 4]).to_bytes(3, 'big'))
        if left_count:
            group = Base64._pack_sextets(values[full_length:])
            bit_count = 6 * left_count
            byte_count = bit_count // 8
            # Drop the low-order filler bits that do not complete a byte (4 or 2 bits).
            group >>= bit_count - 8 * byte_count
            result.extend(group.to_bytes(byte_count, 'big'))
        return result
if __name__ == '__main__':
    # Quick self-check when run as a script.
    print(Base64.encode(b'i\xb7\x1d\xfb\xef\xff'))
    print(Base64.decode("abcd++//"))
    assert Base64.encode(b'i\xb7\x1d\xfb\xef\xff') == "abcd++//"
    assert Base64.decode("abcd++//") == b'i\xb7\x1d\xfb\xef\xff'
    # Padding cases from the notes at the top of the file (==, =, no padding).
    for raw, encoded in ((b'I', 'SQ=='), (b'AM', 'QU0='), (b'TJM', 'VEpN')):
        assert Base64.encode(raw) == encoded
        assert Base64.decode(encoded) == raw