Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/advance/metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ lmdeploy serve api_server \
--enable-metrics
```

You should be able to see multiple API servers added to the proxy server list. Details can be found in `lmdeploy/serve/proxy/proxy_config.json`.
You should be able to see multiple API servers added to the proxy server list. Query `GET /nodes/status` on the proxy for the current replica set.

For example, you may have the following API servers:

Expand Down
5 changes: 4 additions & 1 deletion docs/en/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
sys.path.insert(0, os.path.abspath('../..'))

from lmdeploy.serve.openai.api_server import router # noqa: E402
from lmdeploy.serve.proxy.proxy import app as proxy_server # noqa: E402
from lmdeploy.serve.proxy.app import create_app # noqa: E402
from lmdeploy.serve.proxy.core.config import ProxyConfig # noqa: E402

proxy_server = create_app(ProxyConfig()) # noqa: E402

version_file = '../../lmdeploy/version.py'
with open(version_file) as f:
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/advance/metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ lmdeploy serve api_server \
--enable-metrics
```

您应该能在代理服务器列表中看到多个 API 服务实例。详细信息可以在 `lmdeploy/serve/proxy/proxy_config.json` 中找到
您应该能在代理服务器列表中看到多个 API 服务实例。可通过代理的 `GET /nodes/status` 查看当前副本列表。

例如,您可能会看到如下 API 服务地址:

Expand Down
5 changes: 4 additions & 1 deletion docs/zh_cn/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
sys.path.insert(0, os.path.abspath('../..'))

from lmdeploy.serve.openai.api_server import router # noqa: E402
from lmdeploy.serve.proxy.proxy import app as proxy_server # noqa: E402
from lmdeploy.serve.proxy.app import create_app # noqa: E402
from lmdeploy.serve.proxy.core.config import ProxyConfig # noqa: E402

proxy_server = create_app(ProxyConfig()) # noqa: E402

version_file = '../../lmdeploy/version.py'
with open(version_file) as f:
Expand Down
8 changes: 1 addition & 7 deletions lmdeploy/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,6 @@ def add_parser_proxy():
choices=['random', 'min_expected_latency', 'min_observed_latency'],
default='min_expected_latency',
help='the strategy to dispatch requests to nodes')
parser.add_argument('--disable-cache-status',
action='store_true',
help='Whether to disable cache status of the '
'proxy. If set, the proxy will forget the status '
'of the previous time')

# For Disaggregation
parser.add_argument('--migration-protocol',
type=str,
Expand Down Expand Up @@ -349,7 +343,7 @@ def api_server(args):
@staticmethod
def proxy(args):
"""Proxy server that manages distributed api_server nodes."""
from lmdeploy.serve.proxy.proxy import proxy
from lmdeploy.serve.proxy.cli import proxy
kwargs = convert_args(args)
proxy(**kwargs)

Expand Down
65 changes: 52 additions & 13 deletions lmdeploy/serve/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,31 +1299,70 @@ def dummy_get_device_id():
_set_func('mmengine.logging.logger._get_device_id', dummy_get_device_id)


@router.on_event('startup')
async def startup_event():
async_engine = VariableInterface.async_engine
async_engine.start_loop(asyncio.get_running_loop(), use_async_api=True)
async def _wait_until_listening(url: str, timeout: float = 60.0) -> bool:
"""Wait until the local HTTP server accepts connections."""
import requests

if VariableInterface.proxy_url is None:
health_url = f'{url}/health'
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
try:
await asyncio.to_thread(
requests.get,
health_url,
headers={'accept': 'application/json'},
timeout=2.0,
)
return True
except requests.exceptions.RequestException:
await asyncio.sleep(1)
return False


async def _register_with_proxy() -> None:
"""Register this api_server with the proxy after HTTP is listening."""
proxy_url = VariableInterface.proxy_url
api_server_url = VariableInterface.api_server_url
if proxy_url is None or api_server_url is None:
return
elif getattr(async_engine.engine, 'is_dummy', False):
logger.info('Dummy node started')
if not await _wait_until_listening(api_server_url):
logger.error(f'Service registration timed out waiting for {api_server_url} to listen')
return
try:
import requests

engine_config = VariableInterface.async_engine.backend_config
engine_role = engine_config.role.value if hasattr(engine_config, 'role') else 1
url = f'{VariableInterface.proxy_url}/nodes/add'
data = {'url': VariableInterface.api_server_url, 'status': {'models': get_model_list(), 'role': engine_role}}
url = f'{proxy_url}/nodes/add'
data = {
'url': api_server_url,
'status': {
'models': get_model_list(),
'role': engine_role,
},
}
headers = {'accept': 'application/json', 'Content-Type': 'application/json'}
response = requests.post(url, headers=headers, json=data)

if response.status_code != 200:
raise HTTPException(status_code=response.status_code, detail=response.text)
response = await asyncio.to_thread(requests.post, url, headers=headers, json=data, timeout=60)
if response.status_code != HTTPStatus.OK:
raise RuntimeError(f'HTTP {response.status_code}: {response.text}')
logger.info(f'Service registered with proxy: {api_server_url}')
except Exception as e:
logger.error(f'Service registration failed: {e}')


# Strong references to in-flight proxy-registration tasks. The event loop only
# keeps weak references to tasks, so a task whose result is discarded may be
# garbage-collected before it finishes.
_proxy_registration_tasks: set = set()


@router.on_event('startup')
async def startup_event():
    """Start the engine loop and, if a proxy is configured, register with it.

    Registration is scheduled as a background task instead of being awaited, so
    this startup hook returns without waiting for it. Nothing is scheduled when
    no proxy URL is configured or when the engine reports itself as a dummy node.
    """
    async_engine = VariableInterface.async_engine
    async_engine.start_loop(asyncio.get_running_loop(), use_async_api=True)

    if VariableInterface.proxy_url is None:
        return
    elif getattr(async_engine.engine, 'is_dummy', False):
        logger.info('Dummy node started')
        return

    task = asyncio.create_task(_register_with_proxy(), name='ProxyRegistration')
    # Hold the task until it completes, then drop the reference.
    _proxy_registration_tasks.add(task)
    task.add_done_callback(_proxy_registration_tasks.discard)


@router.on_event('shutdown')
async def shutdown_event():
async_engine = VariableInterface.async_engine
Expand Down
5 changes: 5 additions & 0 deletions lmdeploy/serve/proxy/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

from lmdeploy.serve.proxy.app import create_app
from lmdeploy.serve.proxy.cli import proxy

__all__ = ['create_app', 'proxy']
77 changes: 77 additions & 0 deletions lmdeploy/serve/proxy/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.

import os
from contextlib import asynccontextmanager

import aiohttp
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from lmdeploy.serve.proxy.core.config import ProxyConfig
from lmdeploy.serve.proxy.endpoint import admin, distserve, openai
from lmdeploy.serve.proxy.runtime import ProxyRuntime
from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

_DEFAULT_AIOHTTP_LIMIT = 1024
_DEFAULT_AIOHTTP_LIMIT_PER_HOST = 128


def _read_env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is unset, empty or blank.

    Returns:
        The parsed integer, or ``default``.

    Raises:
        ValueError: If the variable holds something that is not an integer.
            The message names the offending variable.
    """
    raw = os.getenv(name)
    # A blank value is treated like an unset one rather than a parse error.
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f'Environment variable {name} must be an integer, got {raw!r}') from None


def _create_upstream_session() -> aiohttp.ClientSession:
    """Build the shared aiohttp session used to forward to api_server replicas.

    Settings come from the environment: ``AIOHTTP_TIMEOUT`` sets the total
    timeout (unset or empty means no total timeout), while ``AIOHTTP_LIMIT`` and
    ``AIOHTTP_LIMIT_PER_HOST`` set the connector limits and fall back to the
    module defaults.
    """
    raw_timeout = os.getenv('AIOHTTP_TIMEOUT')
    total = int(raw_timeout) if raw_timeout else None
    timeout = aiohttp.ClientTimeout(total=total)

    pool_limit = _read_env_int('AIOHTTP_LIMIT', _DEFAULT_AIOHTTP_LIMIT)
    per_host_limit = _read_env_int('AIOHTTP_LIMIT_PER_HOST', _DEFAULT_AIOHTTP_LIMIT_PER_HOST)
    connector = aiohttp.TCPConnector(limit=pool_limit, limit_per_host=per_host_limit)

    # Log the effective values so operators can see which overrides applied.
    summary = (f'Proxy upstream aiohttp: timeout={timeout.total}, '
               f'limit={connector.limit}, limit_per_host={connector.limit_per_host}. '
               'Override via env AIOHTTP_TIMEOUT, AIOHTTP_LIMIT, AIOHTTP_LIMIT_PER_HOST.')
    logger.info(summary)
    return aiohttp.ClientSession(timeout=timeout, connector=connector)


def create_app(config: ProxyConfig) -> FastAPI:
    """Assemble the proxy server's FastAPI application.

    Args:
        config: Proxy configuration. It is stored on ``app.state.proxy_config``
            and handed to the ``ProxyRuntime`` created at startup.

    Returns:
        The application with CORS, optional API-key authentication and the
        openai, admin and distserve routers installed.
    """

    @asynccontextmanager
    async def lifespan(application: FastAPI):
        # The upstream session is opened when the app starts and closed when
        # the lifespan context exits.
        async with _create_upstream_session() as upstream:
            application.state.runtime = ProxyRuntime(config, upstream)
            yield

    cors_options = dict(
        allow_origins=['*'],
        allow_credentials=True,
        allow_methods=['*'],
        allow_headers=['*'],
    )

    app = FastAPI(docs_url='/', lifespan=lifespan)
    app.state.proxy_config = config
    app.add_middleware(CORSMiddleware, **cors_options)

    # Authentication is installed only when at least one non-empty key exists.
    tokens = [key for key in (config.api_keys or ()) if key]
    if tokens:
        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    for endpoint in (openai, admin, distserve):
        app.include_router(endpoint.router)
    return app
77 changes: 77 additions & 0 deletions lmdeploy/serve/proxy/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.

import os
from typing import Literal

import uvicorn

from lmdeploy.pytorch.disagg.config import DistServeRDMAConfig, RDMALinkType, ServingStrategy
from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol
from lmdeploy.serve.proxy.app import create_app
from lmdeploy.serve.proxy.core.config import ProxyConfig, RoutingStrategy
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


def proxy(server_name: str = '0.0.0.0',
          server_port: int = 8000,
          serving_strategy: Literal['Hybrid', 'DistServe'] = 'Hybrid',
          routing_strategy: Literal['random', 'min_expected_latency', 'min_observed_latency'] = 'min_expected_latency',
          api_keys: list[str] | str | None = None,
          ssl: bool = False,
          log_level: str = 'INFO',
          link_type: Literal['RoCE', 'IB'] = 'RoCE',
          migration_protocol: Literal['RDMA', 'NVLINK'] = 'RDMA',
          dummy_prefill: bool = False,
          disable_gdr: bool = False,
          **kwargs):
    """Build the proxy configuration and serve the proxy app with uvicorn.

    Args:
        server_name: Host passed to uvicorn.
        server_port: Port passed to uvicorn.
        serving_strategy: Name of a ``ServingStrategy`` member.
        routing_strategy: Name understood by ``RoutingStrategy.from_str``.
        api_keys: A single key, a list of keys, or None. An empty string is
            treated as no keys.
        ssl: When true, ``SSL_KEYFILE`` and ``SSL_CERTFILE`` must be set in the
            environment.
        log_level: Level applied to the lmdeploy logger.
        link_type: Name of an ``RDMALinkType`` member.
        migration_protocol: Name of a ``MigrationProtocol`` member.
        dummy_prefill: Stored on the configuration as-is.
        disable_gdr: Inverted into ``with_gdr`` of the RDMA configuration.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``ssl`` is set but either SSL environment variable is
            missing or empty.
    """
    # Normalise api_keys into a list or None.
    if api_keys is None:
        keys = None
    elif isinstance(api_keys, str):
        keys = [api_keys] if api_keys else None
    else:
        keys = list(api_keys)

    # Resolve the enum-valued options in the same order as the config fields.
    strategy = ServingStrategy[serving_strategy]
    routing = RoutingStrategy.from_str(routing_strategy)
    protocol = MigrationProtocol[migration_protocol]
    rdma = DistServeRDMAConfig(link_type=RDMALinkType[link_type], with_gdr=not disable_gdr)

    config = ProxyConfig(
        serving_strategy=strategy,
        routing_strategy=routing,
        migration_protocol=protocol,
        rdma_config=rdma,
        dummy_prefill=dummy_prefill,
        server_name=server_name,
        server_port=server_port,
        api_keys=keys,
        ssl=ssl,
        log_level=log_level,
    )
    app = create_app(config)

    ssl_keyfile = ssl_certfile = None
    if ssl:
        ssl_keyfile = os.environ.get('SSL_KEYFILE')
        ssl_certfile = os.environ.get('SSL_CERTFILE')
        if not (ssl_keyfile and ssl_certfile):
            raise ValueError('SSL is enabled but SSL_KEYFILE and SSL_CERTFILE must be set.')

    logger.setLevel(log_level)
    uvicorn.run(
        app=app,
        host=server_name,
        port=server_port,
        log_level=os.getenv('UVICORN_LOG_LEVEL', 'info').lower(),
        ssl_keyfile=ssl_keyfile,
        ssl_certfile=ssl_certfile,
    )


# Script entry point: expose ``proxy`` as a command line through python-fire.
if __name__ == '__main__':
    # Imported here so that ``fire`` is only needed when run as a script.
    import fire

    fire.Fire(proxy)
6 changes: 6 additions & 0 deletions lmdeploy/serve/proxy/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .config import ProxyConfig, RoutingStrategy
from .replica import ReplicaLoad, ReplicaRegistration

__all__ = ['ProxyConfig', 'ReplicaLoad', 'ReplicaRegistration', 'RoutingStrategy']
43 changes: 43 additions & 0 deletions lmdeploy/serve/proxy/core/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.

import enum
from dataclasses import dataclass

from lmdeploy.pytorch.disagg.config import DistServeRDMAConfig, ServingStrategy
from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol


class RoutingStrategy(enum.Enum):
    """Strategy to dispatch requests to nodes."""

    RANDOM = enum.auto()
    MIN_EXPECTED_LATENCY = enum.auto()
    MIN_OBSERVED_LATENCY = enum.auto()

    @classmethod
    def from_str(cls, name: str) -> 'RoutingStrategy':
        """Return the strategy whose lower-cased member name matches ``name``.

        The lookup and the error message are derived from the enum members, so
        a newly added member is recognised without further changes here.
        Matching ignores letter case and surrounding whitespace.

        Args:
            name: Strategy name, e.g. ``'random'``.

        Returns:
            The matching ``RoutingStrategy`` member.

        Raises:
            ValueError: If ``name`` does not name a member.
        """
        member = None
        if isinstance(name, str):
            # Member names are the upper-cased forms of the accepted strings.
            member = cls.__members__.get(name.strip().upper())
        if member is None:
            supported = ', '.join(item.name.lower() for item in cls)
            raise ValueError(f'Invalid strategy: {name}. Supported: {supported}.')
        return member


@dataclass
class ProxyConfig:
    """Runtime configuration for the proxy server.

    Every field has a default, so ``ProxyConfig()`` is a valid configuration.
    """

    # Serving mode; defaults to Hybrid.
    serving_strategy: ServingStrategy = ServingStrategy.Hybrid
    # How requests are dispatched to nodes (see RoutingStrategy).
    routing_strategy: RoutingStrategy = RoutingStrategy.MIN_EXPECTED_LATENCY
    # NOTE(review): presumably only relevant to DistServe serving - confirm.
    migration_protocol: MigrationProtocol = MigrationProtocol.RDMA
    # Optional RDMA settings; None when not provided.
    rdma_config: DistServeRDMAConfig | None = None
    # NOTE(review): semantics are not visible in this module - confirm with the DistServe code.
    dummy_prefill: bool = False
    # Host and port; the CLI launcher passes the same values to uvicorn.
    server_name: str = '0.0.0.0'
    server_port: int = 8000
    # API keys; create_app installs authentication only when a non-empty key is present.
    api_keys: list[str] | None = None
    # Whether SSL was requested; the CLI launcher reads the key/cert paths from the environment.
    ssl: bool = False
    # Log level name; the CLI launcher applies it to the lmdeploy logger.
    log_level: str = 'INFO'
Comment on lines +4 to +43
Loading
Loading