Address channel connection review comments

This commit is contained in:
hetaoBackend
2026-06-12 09:42:38 +08:00
parent e0cb0df5da
commit 274cae62d8
6 changed files with 66 additions and 14 deletions
@@ -5,11 +5,12 @@ from __future__ import annotations
import base64
import hashlib
import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Any
from cryptography.fernet import Fernet
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -21,6 +22,8 @@ from deerflow.persistence.channel_connections.model import (
)
from deerflow.utils.time import coerce_iso
logger = logging.getLogger(__name__)
class ChannelCredentialCipher:
"""Encrypts provider credentials before they are persisted."""
@@ -195,16 +198,23 @@ class ChannelConnectionRepository:
row = await session.get(ChannelCredentialRow, connection_id)
if row is None:
return None
extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
return {
"connection_id": row.connection_id,
"access_token": self._cipher.decrypt_text(row.encrypted_access_token),
"refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
"token_type": row.token_type,
"expires_at": self._coerce_datetime(row.expires_at),
"refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
"extra": json.loads(extra_raw) if extra_raw else {},
}
try:
extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
return {
"connection_id": row.connection_id,
"access_token": self._cipher.decrypt_text(row.encrypted_access_token),
"refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
"token_type": row.token_type,
"expires_at": self._coerce_datetime(row.expires_at),
"refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
"extra": json.loads(extra_raw) if extra_raw else {},
}
except (InvalidToken, UnicodeError, json.JSONDecodeError):
logger.warning(
"Unable to decrypt channel connection credentials; treating credentials as unavailable",
exc_info=True,
)
return None
@staticmethod
def hash_state(state: str) -> str: