c1b7f1d189
* feat(middleware): inject dynamic context via DynamicContextMiddleware
Move memory and current date out of the system prompt and into a
dedicated <system-reminder> HumanMessage injected once per session
(frozen-snapshot pattern) via a new DynamicContextMiddleware.
This keeps the system prompt byte-exact across all users and sessions,
enabling maximum Anthropic/Bedrock prefix-cache reuse.
Key design decisions:
- ID-swap technique: reminder takes the first HumanMessage's ID
(replacing it in-place via add_messages), original content gets a
derived `{id}__user` ID (appended after). Preserves correct ordering.
- hide_from_ui: True on reminder messages so frontend filters them out.
- Midnight crossing: date-update reminder injected before the current
turn's HumanMessage when the conversation spans midnight.
- INFO-level logging for production diagnostics.
Also adds prompt-caching breakpoint budget enforcement tests and
updates ClaudeChatModel docs to reference the new pattern.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* feat(token-usage): log input/output token detail breakdown in middleware
Extend the LLM token usage log line to include input_token_details and
output_token_details (cache_creation, cache_read, reasoning, audio, etc.)
when present. Adds tests covering Anthropic cache detail logging from
both usage_metadata and response_metadata.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* fix: fix nginx
* fix(middleware): always inject date; gate memory on injection_enabled
Date injection is now unconditional — it is part of the static system
prompt replacement and should always be present. Memory injection
remains gated by `memory.injection_enabled` in the app config.
Previously the entire DynamicContextMiddleware was skipped when
injection_enabled was False, which also suppressed the date.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* fix(lint): format files and correct test assertions for token usage middleware
- ruff format dynamic_context_middleware.py and test_claude_provider_prompt_caching.py
- Remove unused pytest import from test_dynamic_context_middleware.py
- Fix two tests that asserted response_metadata fallback logic that
doesn't exist: replace with tests that match actual middleware behavior
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* fix(middleware): address Copilot review comments on DynamicContextMiddleware
- Use additional_kwargs flag for reminder detection instead of content
substring matching, so user messages containing '<system-reminder>'
are not mistakenly treated as injected reminders
- Generate stable UUID when original HumanMessage.id is None to prevent
ambiguous 'None__user' derived IDs and message collisions
- Downgrade per-turn no-op log to DEBUG; keep actual injection events at INFO
- Add two new tests: missing-id UUID fallback and user-text false-positive
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
---------
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
364 lines
15 KiB
Python
364 lines
15 KiB
Python
"""Custom Claude provider with OAuth Bearer auth, prompt caching, and smart thinking.

Supports two authentication modes:
  1. Standard API key (x-api-key header) — default ChatAnthropic behavior
  2. Claude Code OAuth token (Authorization: Bearer header)
     - Detected by sk-ant-oat prefix
     - Requires anthropic-beta: oauth-2025-04-20,claude-code-20250219
     - Requires billing header in system prompt for all OAuth requests

Auto-loads credentials from explicit runtime handoff:
  - $ANTHROPIC_API_KEY environment variable
  - $CLAUDE_CODE_OAUTH_TOKEN or $ANTHROPIC_AUTH_TOKEN
  - $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR
  - $CLAUDE_CODE_CREDENTIALS_PATH
  - ~/.claude/.credentials.json
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import socket
|
|
import time
|
|
import uuid
|
|
from typing import Any
|
|
|
|
import anthropic
|
|
from langchain_anthropic import ChatAnthropic
|
|
from langchain_core.messages import BaseMessage
|
|
from pydantic import PrivateAttr
|
|
|
|
logger = logging.getLogger(__name__)

# Default for ClaudeChatModel.retry_max_attempts: total tries per request,
# counting the first one.
MAX_RETRIES = 3
# Fraction of the request's max_tokens handed to extended thinking when the
# caller enabled thinking without an explicit budget_tokens.
THINKING_BUDGET_RATIO = 0.8

# Billing header required by Anthropic API for OAuth token access.
# Must be the first system prompt block. Format mirrors Claude Code CLI.
# Override with ANTHROPIC_BILLING_HEADER env var if the hardcoded version drifts.
_DEFAULT_BILLING_HEADER = "x-anthropic-billing-header: cc_version=2.1.85.351; cc_entrypoint=cli; cch=6c6d5;"
# Resolved once at import time; changing the env var later has no effect.
OAUTH_BILLING_HEADER = os.environ.get("ANTHROPIC_BILLING_HEADER", _DEFAULT_BILLING_HEADER)
|
|
|
|
|
|
class ClaudeChatModel(ChatAnthropic):
    """ChatAnthropic with OAuth Bearer auth, prompt caching, and smart thinking.

    Config example:
        - name: claude-sonnet-4.6
          use: deerflow.models.claude_provider:ClaudeChatModel
          model: claude-sonnet-4-6
          max_tokens: 16384
          enable_prompt_caching: true
    """

    # Custom fields
    # Add ephemeral cache_control markers to outgoing payloads (forced off for OAuth).
    enable_prompt_caching: bool = True
    # Number of trailing messages whose content blocks are cache candidates.
    prompt_cache_size: int = 3
    # Fill in thinking.budget_tokens when thinking is enabled without a budget.
    auto_thinking_budget: bool = True
    # Total attempts per request (first try included); must be >= 1.
    retry_max_attempts: int = MAX_RETRIES
    _is_oauth: bool = PrivateAttr(default=False)
    _oauth_access_token: str = PrivateAttr(default="")

    model_config = {"arbitrary_types_allowed": True}

    def _validate_retry_config(self) -> None:
        """Reject a retry configuration that would never attempt a request.

        Raises:
            ValueError: if ``retry_max_attempts`` is below 1.
        """
        if self.retry_max_attempts < 1:
            raise ValueError("retry_max_attempts must be >= 1")

    def model_post_init(self, __context: Any) -> None:
        """Auto-load credentials and configure OAuth if needed."""
        from pydantic import SecretStr

        from deerflow.models.credential_loader import (
            OAUTH_ANTHROPIC_BETAS,
            is_oauth_token,
            load_claude_code_credential,
        )

        self._validate_retry_config()

        # Extract actual key value (SecretStr.str() returns '**********')
        current_key = ""
        if self.anthropic_api_key:
            if hasattr(self.anthropic_api_key, "get_secret_value"):
                current_key = self.anthropic_api_key.get_secret_value()
            else:
                current_key = str(self.anthropic_api_key)

        # Try the explicit Claude Code OAuth handoff sources if no valid key.
        if not current_key or current_key in ("your-anthropic-api-key",):
            cred = load_claude_code_credential()
            if cred:
                current_key = cred.access_token
                logger.info("Using Claude Code CLI credential (source: %s)", cred.source)
            else:
                logger.warning("No Anthropic API key or explicit Claude Code OAuth credential found.")

        # Detect OAuth token and configure Bearer auth
        if is_oauth_token(current_key):
            self._is_oauth = True
            self._oauth_access_token = current_key
            # Set the token as api_key temporarily (will be swapped to auth_token on client)
            self.anthropic_api_key = SecretStr(current_key)
            # Add required beta headers for OAuth
            self.default_headers = {
                **(self.default_headers or {}),
                "anthropic-beta": OAUTH_ANTHROPIC_BETAS,
            }
            # OAuth tokens have a limit of 4 cache_control blocks — disable prompt caching
            self.enable_prompt_caching = False
            logger.info("OAuth token detected — will use Authorization: Bearer header")
        else:
            if current_key:
                self.anthropic_api_key = SecretStr(current_key)

        # Ensure api_key is SecretStr
        if isinstance(self.anthropic_api_key, str):
            self.anthropic_api_key = SecretStr(self.anthropic_api_key)

        super().model_post_init(__context)

        # Patch clients immediately after creation for OAuth Bearer auth.
        # This must happen after super() because clients are lazily created.
        if self._is_oauth:
            self._patch_client_oauth(self._client)
            self._patch_client_oauth(self._async_client)

    def _patch_client_oauth(self, client: Any) -> None:
        """Swap api_key → auth_token on an Anthropic SDK client for OAuth Bearer auth."""
        # Only touch objects that expose both attributes; anything else is left alone.
        if hasattr(client, "api_key") and hasattr(client, "auth_token"):
            client.api_key = None
            client.auth_token = self._oauth_access_token

    def _get_request_payload(
        self,
        input_: Any,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict:
        """Override to inject prompt caching, thinking budget, and OAuth billing."""
        payload = super()._get_request_payload(input_, stop=stop, **kwargs)

        # Billing runs first so the billing block exists before cache candidates are collected.
        if self._is_oauth:
            self._apply_oauth_billing(payload)

        if self.enable_prompt_caching:
            self._apply_prompt_caching(payload)

        if self.auto_thinking_budget:
            self._apply_thinking_budget(payload)

        return payload

    @staticmethod
    def _without_billing_header(text: str) -> str:
        """Return *text* with every billing-header occurrence removed and edges trimmed."""
        return text.replace(OAUTH_BILLING_HEADER, "").strip()

    def _apply_oauth_billing(self, payload: dict) -> None:
        """Inject the billing header block required for all OAuth requests.

        The billing block is always placed first in the system list. Any existing
        occurrence of the header is removed to avoid duplication or out-of-order
        positioning; text that merely surrounds an embedded header is kept rather
        than discarded with it. Also fills in ``metadata.user_id`` when absent.
        """
        billing_block = {"type": "text", "text": OAUTH_BILLING_HEADER}

        system = payload.get("system")
        if isinstance(system, list):
            kept: list[Any] = []
            for block in system:
                text = block.get("text") if isinstance(block, dict) else None
                # Non-string text (e.g. None) cannot contain the header; pass it through.
                if isinstance(text, str) and OAUTH_BILLING_HEADER in text:
                    remainder = self._without_billing_header(text)
                    if remainder:
                        # Copy so the caller's block is not mutated.
                        kept.append({**block, "text": remainder})
                    continue
                kept.append(block)
            payload["system"] = [billing_block, *kept]
        elif isinstance(system, str) and system:
            if OAUTH_BILLING_HEADER in system:
                remainder = self._without_billing_header(system)
                payload["system"] = [billing_block]
                if remainder:
                    payload["system"].append({"type": "text", "text": remainder})
            else:
                payload["system"] = [billing_block, {"type": "text", "text": system}]
        else:
            # Missing, empty, or unexpected system value: the billing block stands alone.
            payload["system"] = [billing_block]

        # Add metadata.user_id required by the API for OAuth billing validation
        if not isinstance(payload.get("metadata"), dict):
            payload["metadata"] = {}
        if "user_id" not in payload["metadata"]:
            # Generate a stable device_id from the machine's hostname
            hostname = socket.gethostname()
            device_id = hashlib.sha256(f"deerflow-{hostname}".encode()).hexdigest()
            # A fresh session_id is generated for every payload that lacks user_id.
            session_id = str(uuid.uuid4())
            payload["metadata"]["user_id"] = json.dumps(
                {
                    "device_id": device_id,
                    "account_uuid": "deerflow",
                    "session_id": session_id,
                }
            )

    @staticmethod
    def _is_cacheable_block(block: Any) -> bool:
        """Return True if *block* is a dict that may carry a cache_control marker.

        Thinking / redacted-thinking blocks and empty text blocks are excluded:
        Anthropic's prompt-caching documentation lists them as not directly cacheable.
        """
        if not isinstance(block, dict):
            return False
        block_type = block.get("type")
        if block_type in ("thinking", "redacted_thinking"):
            return False
        if block_type == "text" and not block.get("text"):
            return False
        return True

    def _apply_prompt_caching(self, payload: dict) -> None:
        """Apply ephemeral cache_control to system, recent messages, and last tool definition.

        Uses a budget of MAX_CACHE_BREAKPOINTS (4) breakpoints — the hard limit
        enforced by both the Anthropic API and AWS Bedrock. Breakpoints are
        placed on the *last* eligible blocks because later breakpoints cover a
        larger prefix and yield better cache hit rates.

        The system prompt is expected to be fully static (no per-user memory or
        current date). Dynamic context is injected separately by
        DynamicContextMiddleware as a <system-reminder> HumanMessage.

        String-valued system prompts and message contents are rewritten in place
        into single-element block lists so a marker can be attached.
        """
        MAX_CACHE_BREAKPOINTS = 4

        # Collect candidate blocks in this order:
        #   1. system text blocks
        #   2. content blocks of the last prompt_cache_size messages
        #   3. the last tool definition
        candidates: list[dict] = []

        # 1. System blocks
        system = payload.get("system")
        if system and isinstance(system, list):
            candidates.extend(block for block in system if isinstance(block, dict) and block.get("type") == "text" and self._is_cacheable_block(block))
        elif system and isinstance(system, str):
            new_block: dict = {"type": "text", "text": system}
            payload["system"] = [new_block]
            candidates.append(new_block)

        # 2. Recent message blocks
        messages = payload.get("messages", [])
        cache_start = max(0, len(messages) - self.prompt_cache_size)
        for msg in messages[cache_start:]:
            if not isinstance(msg, dict):
                continue
            content = msg.get("content")
            if isinstance(content, list):
                # Skip blocks the API refuses to cache (thinking, empty text).
                candidates.extend(block for block in content if self._is_cacheable_block(block))
            elif isinstance(content, str) and content:
                new_block = {"type": "text", "text": content}
                msg["content"] = [new_block]
                candidates.append(new_block)

        # 3. Last tool definition
        # NOTE(review): Anthropic documents the cache prefix order as
        # tools → system → messages, whereas this list ends with the tool —
        # confirm the tool should always take one of the four slots.
        tools = payload.get("tools", [])
        if tools and isinstance(tools[-1], dict):
            candidates.append(tools[-1])

        # Apply cache_control only to the last MAX_CACHE_BREAKPOINTS candidates
        # to stay within the API limit.
        for block in candidates[-MAX_CACHE_BREAKPOINTS:]:
            block["cache_control"] = {"type": "ephemeral"}

    def _apply_thinking_budget(self, payload: dict) -> None:
        """Auto-allocate thinking budget (80% of max_tokens).

        Only acts when ``payload["thinking"]`` is a dict of type ``"enabled"``
        without a truthy ``budget_tokens``; otherwise the payload is untouched.
        """
        thinking = payload.get("thinking")
        if not thinking or not isinstance(thinking, dict):
            return
        if thinking.get("type") != "enabled":
            return
        if thinking.get("budget_tokens"):
            return

        max_tokens = payload.get("max_tokens")
        if max_tokens is None:
            # Key absent or explicitly None: fall back to the same default as before.
            max_tokens = 8192
        thinking["budget_tokens"] = int(max_tokens * THINKING_BUDGET_RATIO)

    @staticmethod
    def _strip_cache_control(payload: dict) -> None:
        """Remove cache_control markers before OAuth requests reach Anthropic.

        Clears markers from system blocks, messages, each message's content
        blocks, and tool definitions. Mutates *payload* in place.
        """
        for section in ("system", "messages"):
            items = payload.get(section)
            if not isinstance(items, list):
                continue
            for item in items:
                if not isinstance(item, dict):
                    continue
                item.pop("cache_control", None)
                content = item.get("content")
                if isinstance(content, list):
                    for block in content:
                        if isinstance(block, dict):
                            block.pop("cache_control", None)

        tools = payload.get("tools")
        if isinstance(tools, list):
            for tool in tools:
                if isinstance(tool, dict):
                    tool.pop("cache_control", None)

    def _create(self, payload: dict) -> Any:
        """Send the request, stripping cache_control first when using OAuth."""
        if self._is_oauth:
            self._strip_cache_control(payload)
        return super()._create(payload)

    async def _acreate(self, payload: dict) -> Any:
        """Async counterpart of ``_create``."""
        if self._is_oauth:
            self._strip_cache_control(payload)
        return await super()._acreate(payload)

    def _retry_delay_seconds(self, attempt: int, error: Exception) -> float:
        """Log the upcoming retry and return how long to wait before it, in seconds."""
        wait_ms = self._calc_backoff_ms(attempt, error)
        reason = "Rate limited" if isinstance(error, anthropic.RateLimitError) else "Server error"
        logger.warning("%s, retrying attempt %s/%s after %sms", reason, attempt, self.retry_max_attempts, wait_ms)
        return wait_ms / 1000

    def _generate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
        """Override with OAuth patching and retry logic.

        Retries on ``anthropic.RateLimitError`` and ``anthropic.InternalServerError``
        up to ``retry_max_attempts`` total attempts, then re-raises the last error.
        At least one attempt is always made.
        """
        if self._is_oauth:
            self._patch_client_oauth(self._client)

        attempt = 1
        while True:
            try:
                return super()._generate(messages, stop=stop, **kwargs)
            except (anthropic.RateLimitError, anthropic.InternalServerError) as exc:
                if attempt >= self.retry_max_attempts:
                    raise
                delay = self._retry_delay_seconds(attempt, exc)
            time.sleep(delay)
            attempt += 1

    async def _agenerate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
        """Async override with OAuth patching and retry logic.

        Same retry policy as ``_generate``, waiting with ``asyncio.sleep``.
        """
        import asyncio

        if self._is_oauth:
            self._patch_client_oauth(self._async_client)

        attempt = 1
        while True:
            try:
                return await super()._agenerate(messages, stop=stop, **kwargs)
            except (anthropic.RateLimitError, anthropic.InternalServerError) as exc:
                if attempt >= self.retry_max_attempts:
                    raise
                delay = self._retry_delay_seconds(attempt, exc)
            await asyncio.sleep(delay)
            attempt += 1

    @staticmethod
    def _calc_backoff_ms(attempt: int, error: Exception) -> int:
        """Exponential backoff with a fixed 20% buffer.

        Args:
            attempt: 1-based number of the attempt that just failed.
            error: the exception raised; if it carries a ``response`` with a
                usable ``Retry-After`` header (seconds), that value wins.

        Returns:
            Milliseconds to wait: 2000 * 2**(attempt - 1) plus 20%, or the
            server-provided ``Retry-After`` converted to milliseconds. The
            result is never negative.
        """
        backoff_ms = 2000 * (1 << (attempt - 1))
        total_ms = backoff_ms + int(backoff_ms * 0.2)

        response = getattr(error, "response", None)
        if response is None:
            return total_ms

        retry_after = response.headers.get("Retry-After")
        if not retry_after:
            return total_ms

        try:
            # float() accepts both "5" and "1.5"; nan/inf fail in int() below.
            retry_after_ms = int(float(retry_after) * 1000)
        except (ValueError, TypeError, OverflowError):
            # Unparseable (e.g. an HTTP-date): keep the computed backoff.
            return total_ms

        # A negative delay would make time.sleep raise; treat it as invalid.
        return retry_after_ms if retry_after_ms >= 0 else total_ms
|