diff --git a/src/acct/accounting_flow.py b/src/acct/accounting_flow.py index 08287113..0b35ce37 100644 --- a/src/acct/accounting_flow.py +++ b/src/acct/accounting_flow.py @@ -16,8 +16,8 @@ def accounting(cls, request: AcctRequest, acct_user: AcctUser): log.debug(f'IN: {request.iut}|{acct_user.outer_username}|{acct_user.user_mac}') # 查找用户密码 - user = Account.get(username=acct_user.outer_username) - if not user: + account = Account.get(username=acct_user.outer_username) + if not account: return # 每隔x秒清理会话 diff --git a/src/auth/chap_flow.py b/src/auth/chap_flow.py index ef623b5c..1671b9f4 100644 --- a/src/auth/chap_flow.py +++ b/src/auth/chap_flow.py @@ -16,12 +16,13 @@ def authenticate(cls, request: AuthRequest, auth_user: AuthUser): session = BaseSession(auth_user=auth_user) # 查找用户密码 account_name = session.auth_user.outer_username - user = Account.get(username=account_name) - if not user: + account: Account = Account.get(username=account_name) + if not account: raise AccessReject() else: # 保存用户密码 - session.auth_user.set_user_password(user.password) + session.auth_user.set_user_password(account.password) + session.auth_user.set_user_speed(account.speed) def is_correct_password() -> bool: return Chap.is_correct_challenge_value(request=request, user_password=session.auth_user.user_password) @@ -42,7 +43,7 @@ def access_accept(cls, request: AuthRequest, session: BaseSession): request.ap_mac, ] log.info(f'OUT: accept|{"|".join(data)}|') - reply = AuthResponse.create_access_accept(request=request) + reply = AuthResponse.create_access_accept(request=request, session=session) return request.reply_to(reply) @classmethod diff --git a/src/auth/eap_peap_gtc_flow.py b/src/auth/eap_peap_gtc_flow.py index d3f2afbd..98a366af 100644 --- a/src/auth/eap_peap_gtc_flow.py +++ b/src/auth/eap_peap_gtc_flow.py @@ -230,12 +230,13 @@ def peap_challenge_gtc_password(cls, request: AuthRequest, eap: EapPacket, peap: session.auth_user.set_peap_username(account_name) # 查找用户密码 - user = Account.get(username=account_name) - if not user: + account = Account.get(username=account_name) + if not account: raise AccessReject() else: # 保存用户密码 - session.auth_user.set_user_password(user.radius_password) + session.auth_user.set_user_password(account.radius_password) + session.auth_user.set_user_speed(account.speed) # 返回数据 response_data = b'Password' @@ -262,7 +263,7 @@ def peap_challenge_success(cls, request: AuthRequest, eap: EapPacket, peap: EapP tls_decrypt_data = libhostapd.decrypt(session.tls_connection, peap.tls_data) eap_password = EapPacket.parse(packet=tls_decrypt_data) auth_password = eap_password.type_data.decode() - log.debug(f'PEAP account: {session.auth_user.peap_username}, packet_password: {auth_password}') + log.debug(f'PEAP user: {session.auth_user.peap_username}, packet_password: {auth_password}') def is_correct_password() -> bool: return session.auth_user.user_password == auth_password @@ -309,8 +310,8 @@ def access_accept(cls, request: AuthRequest, session: EapPeapSession): request.ap_mac, ] log.info(f'OUT: accept|{"|".join(data)}|') - reply = AuthResponse.create_access_accept(request=request) - reply['State'] = session.session_id.encode() # octets 传入 bytes + reply = AuthResponse.create_access_accept(request=request, session=session) + reply['State'] = session.session_id.encode() # octets log.debug(f'msk: {session.msk}, secret: {reply.secret}, authenticator: {request.authenticator}') reply['MS-MPPE-Recv-Key'], reply['MS-MPPE-Send-Key'] = create_mppe_recv_key_send_key(session.msk, reply.secret, request.authenticator) reply['EAP-Message'] = struct.pack('!B B H', EapPacket.CODE_EAP_SUCCESS, session.current_eap_id-1, 4) # eap_id抓包是这样, 不要惊讶! diff --git a/src/auth/eap_peap_mschapv2_flow.py b/src/auth/eap_peap_mschapv2_flow.py index 86dbed95..84db032f 100644 --- a/src/auth/eap_peap_mschapv2_flow.py +++ b/src/auth/eap_peap_mschapv2_flow.py @@ -239,12 +239,13 @@ def peap_challenge_mschapv2_random(cls, request: AuthRequest, eap: EapPacket, pe # 保存用户名 session.auth_user.set_peap_username(account_name) # 查找用户密码 - user = Account.get(username=account_name) - if not user: + account = Account.get(username=account_name) + if not account: raise AccessReject() else: # 保存用户密码 - session.auth_user.set_user_password(user.radius_password) + session.auth_user.set_user_password(account.radius_password) + session.auth_user.set_user_speed(account.speed) # 返回数据 # MSCHAPV2_OP_CHALLENGE(01) + 与EAP_id相同(07) + MSCHAPV2_OP 到结束的长度(00 1c) + @@ -436,7 +437,7 @@ def access_accept(cls, request: AuthRequest, session: EapPeapSession): request.ap_mac, ] log.info(f'OUT: accept|{"|".join(data)}|') - reply = AuthResponse.create_access_accept(request=request) + reply = AuthResponse.create_access_accept(request=request, session=session) reply['State'] = session.session_id.encode() log.trace(f'msk: {session.msk}, secret: {reply.secret}, authenticator: {request.authenticator}') reply['MS-MPPE-Recv-Key'], reply['MS-MPPE-Send-Key'] = create_mppe_recv_key_send_key(session.msk, reply.secret, request.authenticator) diff --git a/src/auth/mschap_flow.py b/src/auth/mschap_flow.py index 07220275..e1952ec2 100644 --- a/src/auth/mschap_flow.py +++ b/src/auth/mschap_flow.py @@ -31,6 +31,7 @@ def authenticate(cls, request: AuthRequest, auth_user: AuthUser): raise AccessReject() # 保存用户密码 session.auth_user.set_user_password(account.radius_password) + session.auth_user.set_user_speed(account.speed) ################ username = session.auth_user.outer_username @@ -109,6 +110,6 @@ def access_accept(cls, request: AuthRequest, session: BaseSession): request.ap_mac, ] log.info(f'OUT: accept|{"|".join(data)}|') - reply = AuthResponse.create_access_accept(request=request) + reply = AuthResponse.create_access_accept(request=request, session=session) reply['MS-CHAP2-Success'] = session.extra['MS-CHAP2-Success'] return request.reply_to(reply) diff --git a/src/auth/pap_flow.py b/src/auth/pap_flow.py index 265e9641..80879041 100644 --- a/src/auth/pap_flow.py +++ b/src/auth/pap_flow.py @@ -70,5 +70,5 @@ def access_accept(cls, request: AuthRequest, session: BaseSession): request.ap_mac, ] log.info(f'OUT: accept|{"|".join(data)}|') - reply = AuthResponse.create_access_accept(request=request) + reply = AuthResponse.create_access_accept(request=request, session=session) return request.reply_to(reply) diff --git a/src/child_pyrad/packet.py b/src/child_pyrad/packet.py index 3a29eae6..6ed79b85 100644 --- a/src/child_pyrad/packet.py +++ b/src/child_pyrad/packet.py @@ -6,6 +6,7 @@ from .eap_peap_packet import EapPeapPacket from controls.stat import ApStat, UserStat from loguru import logger as log +from auth.session import BaseSession from settings import ACCOUNTING_INTERVAL @@ -88,17 +89,23 @@ class AuthResponse(AuthPacket): # 使用父类初始化自己 @classmethod - def create_access_accept(cls, request: AuthRequest) -> AuthPacket: + def create_access_accept(cls, request: AuthRequest, session: BaseSession) -> AuthPacket: UserStat.report_user_online(username=request.username, user_mac=request.user_mac, ap_mac=request.ap_mac) ApStat.report_ap_online(username=request.username, ap_mac=request.ap_mac) # reply = request.create_reply(code=Packet.CODE_ACCESS_ACCEPT) - # reply['Session-Timeout'] = 600 # 用户可用的剩余时间 - # reply['H3C-Input-Peak-Rate'] = int(self.bandwidth_max_up) # 用户到NAS的峰值速率, 以bps为单位. 1/8字节每秒 - # reply['H3C-Output-Peak-Rate'] = int(self.bandwidth_max_down) # NAS到用户的峰值速率, 以bps为单位. 1/8字节每秒 reply['Idle-Timeout'] = 86400 # 用户的闲置切断时间 reply['Acct-Interim-Interval'] = ACCOUNTING_INTERVAL - # reply['Class'] = '\x7f'.join(('EAP-PEAP', session.auth_user.peap_username, session.session_id)) # Access-Accept发送给AC, AC在计费报文内会携带Class值上报 + + if session.auth_user.user_speed: + # 上传速度. 用户到NAS的峰值速率, 以bps为单位. 1/8字节每秒 + reply['H3C-Input-Peak-Rate'] = session.auth_user.user_speed * 1024 * 1024 * 8 / 4 + # 下载速度. NAS到用户的峰值速率, 以bps为单位. 1/8字节每秒 + reply['H3C-Output-Peak-Rate'] = session.auth_user.user_speed * 1024 * 1024 * 8 + reply['Filter-Id'] = f'pay_user_{session.auth_user.user_speed}m' + # reply['Session-Timeout'] = 600 # 用户可用的剩余时间 + # reply['Class'] = '\x7f'.join(('EAP-PEAP', session.auth_user.peap_username, session.session_id)) # Access-Accept发送给AC, AC在计费报文内会携带Class值上报 + return reply @classmethod diff --git a/src/controls/user.py b/src/controls/user.py index dc50d5b8..c4c0e905 100644 --- a/src/controls/user.py +++ b/src/controls/user.py @@ -8,8 +8,9 @@ def __init__(self, request: AuthRequest): # 提取报文 self.outer_username: str = request.username self.peap_username: str = '' - self.user_mac = request.user_mac # mac地址 + self.user_mac = request.user_mac # mac地址 self.user_password: str = '' + self.user_speed: int = 0 # 兆 self.server_challenge: bytes = b'' self.peer_challenge: bytes = b'' @@ -19,6 +20,9 @@ def set_peap_username(self, account_name: str): def set_user_password(self, password: str): self.user_password = password + def set_user_speed(self, speed: int): + self.user_speed = speed + def set_server_challenge(self, server_challenge: bytes): self.server_challenge = server_challenge diff --git a/src/models/account.py b/src/models/account.py index 2e0ab7de..4bf142b4 100644 --- a/src/models/account.py +++ b/src/models/account.py @@ -22,6 +22,7 @@ class Role(ModelEnum): password = Column(String(255)) radius_password = Column(String(255)) role = Column(String(32)) + speed = Column(Integer) expired_at = Column(DateTime) def __repr__(self): diff --git a/src/processor/sync_user.py b/src/processor/sync_user.py index 37a775e6..cb22d94c 100755 --- a/src/processor/sync_user.py +++ b/src/processor/sync_user.py @@ -38,19 +38,19 @@ def run(self): # expired_at_dt = parse(expired_at) # datetime 类型 expired_at_str = expired_at_dt.strftime('%Y-%m-%d %H:%M:%S') # 字符串类型 - user = session.query(Account).filter(Account.username == username).first() - if not user: - new_user = Account(username=username, password=password, expired_at=expired_at_dt) - session.add(new_user) + account = session.query(Account).filter(Account.username == username).first() + if not account: + new_account = Account(username=username, password=password, expired_at=expired_at_dt) + session.add(new_account) session.commit() - log.info(f'insert user: {username}') + log.info(f'insert account: {username}') else: # sync 同步用户 - if user.expired_at.strftime('%Y-%m-%d %H:%M:%S') != expired_at_str or user.password != password: - user.expired_at = expired_at_dt - user.password = password + if account.expired_at.strftime('%Y-%m-%d %H:%M:%S') != expired_at_str or account.password != password: + account.expired_at = expired_at_dt + account.password = password session.commit() - log.info(f'update user: {user.username}') + log.info(f'update account: {account.username}') TaskLoop().start()